BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit ef76f664
Authored Oct 28, 2021 by Liu-xiandong
Committed by GitHub on Oct 28, 2021
Rewrite Softmax in Kernel Primitive API, test=develop (#36706)
Parent b151a451
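
For orientation: the kernels this commit rewrites compute a numerically stable softmax along the last axis (the file's own comment below calls it "softmax forward for axis=-1"). A minimal host-side reference of that computation, in illustrative code that is not part of the commit:

#include <cmath>
#include <cstddef>

// Reference softmax over one row of length n: subtract the row max before
// exponentiating so exp() cannot overflow, then normalize by the sum. The
// CUDA kernels in this diff run the same three passes (max, exp-and-sum,
// divide) with one warp per row and vectorized memory access.
void SoftmaxRowRef(const float* x, float* y, std::size_t n) {
  float max_v = -INFINITY;
  for (std::size_t i = 0; i < n; ++i) max_v = std::fmax(max_v, x[i]);
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    y[i] = std::exp(x[i] - max_v);
    sum += y[i];
  }
  for (std::size_t i = 0; i < n; ++i) y[i] /= sum;
}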
Showing 1 changed file with 191 additions and 210 deletions.
paddle/fluid/operators/softmax_cudnn_op.cu.h  (+191 -210)
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
@@ -99,6 +100,97 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
   }
 }

+namespace kps = paddle::operators::kernel_primitives;
+
+template <typename Tx, typename Ty = Tx>
+struct ReduceMaxFunctor {
+  inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return max(a, b);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpSubFunctor {
+  HOSTDEVICE inline ExpSubFunctor() { y = static_cast<Tx>(0.0f); }
+
+  HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x - y));
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpMulFunctor {
+  HOSTDEVICE inline ExpMulFunctor() { y = static_cast<Tx>(1.0f); }
+
+  HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x) * y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnarySubFunctor {
+  HOSTDEVICE inline UnarySubFunctor() { y = static_cast<Tx>(0.0f); }
+
+  HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x - y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryLogFunctor {
+  HOSTDEVICE inline UnaryLogFunctor() {}
+
+  HOSTDEVICE explicit inline UnaryLogFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::log(x));
+  }
+};
+
+template <typename Tx, typename Ty>
+struct DataTransFunctor {
+  HOSTDEVICE inline DataTransFunctor() {}
+
+  HOSTDEVICE explicit inline DataTransFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return x == -std::numeric_limits<Tx>::infinity()
+               ? -std::numeric_limits<Ty>::infinity()
+               : static_cast<Ty>(x);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryDivFunctor {
+  HOSTDEVICE inline UnaryDivFunctor() { n_inv = static_cast<Tx>(1.0f); }
+
+  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n) : n_inv((Tx)(1.0 / n)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x * n_inv);
+  }
+
+ private:
+  Tx n_inv;
+};
+
 /*
 Core function of computing softmax forward for axis=-1.
 The computation includes
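The functors added above are small value types: each captures its scalar state in the constructor and exposes an `operator()` that the kernel primitives apply element by element. A sketch of how a primitive like `kps::ElementwiseUnary` consumes such a functor over a thread's register buffer (an assumption about its shape, not the actual kernel_primitives implementation):

// Sketch: apply a unary functor to NX register-resident elements, the way
// kps::ElementwiseUnary does for the functors defined in this hunk.
template <typename InT, typename OutT, int NX, typename OpFunc>
__device__ __forceinline__ void ElementwiseUnarySketch(OutT* out,
                                                       const InT* in,
                                                       OpFunc compute) {
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    out[i] = compute(in[i]);
  }
}

// Used the way the forward kernel below uses the real primitive:
//   ElementwiseUnarySketch<AccT, AccT, kVItem>(
//       &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));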
@@ -117,12 +209,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
-  constexpr int kIterations = kDimCeil / kWarpSize;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoops = kDimCeil / kWarpSize;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
+  using kMode = kps::details::ReduceMode;

   // max index to read
   int idx_max_v[kBatchSize];
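To make these compile-time constants concrete, here is one illustrative instantiation (assumed values, not from the commit): with Log2Elements = 10, T = float, and VecT = float4:

#include <cuda_runtime.h>  // for float4

// Worked example of the constants above (resulting values in comments).
constexpr int Log2Elements = 10;
constexpr int kDimCeil = 1 << Log2Elements;                          // 1024
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;           // 32
constexpr int kVSize = sizeof(float4) / sizeof(float);               // 4
constexpr int kLoops = kDimCeil / kWarpSize;                         // 32
constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;  // 8
constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;                 // 1
// Each lane therefore issues 8 float4 loads to cover its 32 elements,
// and each warp handles one row of 1024 elements.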
@@ -133,146 +227,51 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   }

   // read data from global memory
-  AccT srcdata[kBatchSize][kIterationsV][kVSize];
+  AccT srcdata[kBatchSize][kLoopsV][kVSize];
+  kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
+  T src_tmp[kBatchSize][kLoopsV][kVSize];
+  kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity());
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    // read data
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (src_idx < idx_max_v[i]) {
-          srcdata[i][it][0] =
-              static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
-        } else {
-          srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
-        }
-      } else {
-        const VecT* src_v =
-            reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-        if (src_idx < idx_max_v[i]) {
-          VecT srctmp = src_v[src_idx];
-          const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
-          }
-        } else {
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
-          }
-        }
-      }
-    }
-  }
-
-  // compute max value
-  AccT max_value[kBatchSize];
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    AccT valmax = srcdata[i][0][0];
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
-    }
-    max_value[i] = valmax;
-
-    // it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-      AccT valmax = srcdata[i][it][0];
-#pragma unroll
-      for (int s = 1; s < kVSize; ++s) {
-        valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
-      }
-      max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
-    }
-  }
-  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
+    kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
+        &srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor<T, AccT>());
+  }
+
+  // compute max
+  AccT max[kBatchSize];
+  kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>,
+              kMode::kLocalMode>(&max[0], &srcdata[0][0][0],
+                                 ReduceMaxFunctor<AccT>(), true);
+  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);

   // compute sum
-  AccT sum[kBatchSize];
+  AccT sum[kBatchSize] = {0};
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    if (LogMode) {
-      sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
-    } else {
-      srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
-      sum[i] = srcdata[i][0][0];
-    }
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      if (LogMode) {
-        sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
-      } else {
-        srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
-        sum[i] += srcdata[i][0][s];
-      }
-    }
-
-    // it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
-        } else {
-          srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
-          sum[i] += srcdata[i][it][s];
-        }
-      }
-    }
-  }
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
+        &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
+  }
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kMode::kLocalMode>(&sum[0], &srcdata[0][0][0],
+                                 kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

   // write result to global memory
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    if (LogMode) {
-      sum[i] = std::log(sum[i]);
-    }
-
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (idx < idx_max_v[i]) {
-          if (LogMode) {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] - max_value[i] - sum[i];
-          } else {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] / sum[i];
-          }
-        } else {
-          break;
-        }
-      } else {
-        VecT* softmax_v =
-            reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
-        VecT tmpdata;
-        T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-#pragma unroll
-        for (int s = 0; s < kVSize; ++s) {
-          if (LogMode) {
-            tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
-          } else {
-            tmpptr[s] = srcdata[i][it][s] / sum[i];
-          }
-        }
-        if (idx < idx_max_v[i]) {
-          softmax_v[idx] = tmpdata;
-        } else {
-          break;
-        }
-      }
-    }
-  }
+    kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    int softmax_ptr = (first_batch + i) * stride;
+    VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[softmax_ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
+  }
 }
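Both kernels finish their reductions with WarpReduceMax/WarpReduceSum, which combine the per-lane partial results across the warp. A minimal sketch of such a butterfly warp reduction (an assumed shape for WarpReduceSum, not this file's exact implementation):

// Butterfly warp-sum sketch: each of kWarpSize lanes holds kBatchSize
// partial sums; after log2(kWarpSize) xor-shuffle rounds, every lane holds
// the complete per-batch sums.
template <typename T, int kBatchSize, int kWarpSize>
__device__ __forceinline__ void WarpReduceSumSketch(T* sum) {
#pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < kBatchSize; ++i) {
      sum[i] += __shfl_xor_sync(0xffffffff, sum[i], offset, kWarpSize);
    }
  }
}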
@@ -293,101 +292,82 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
-  constexpr int kIterations = kDimCeil / kWarpSize;
+  constexpr int kLoops = kDimCeil / kWarpSize;
   constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   int element_count_v = element_count / kVSize;
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
-  int local_batches = batch_size - first_batch;
-  if (local_batches > kBatchSize) {
-    local_batches = kBatchSize;
-  }
-
-  // read data from global memory
-  VecT src_reg[kBatchSize][kIterationsV];
-  VecT grad_reg[kBatchSize][kIterationsV];
-
-  for (int i = 0; i < kBatchSize; ++i) {
-    const VecT* src_v =
-        reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-    const VecT* grad_v =
-        reinterpret_cast<const VecT*>(&grad[(first_batch + i) * stride]);
-
-    // max index to read
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-
-    // read data
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (src_idx < idx_max_v) {
-        src_reg[i][it] = src_v[src_idx];
-        grad_reg[i][it] = grad_v[src_idx];
-      } else {
-#pragma unroll
-        for (int s = 0; s < kVSize; s++) {
-          reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
-          reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
-        }
-      }
-    }
-  }
+  int local_batches = min(batch_size - first_batch, kBatchSize);
+
+  // max index to read
+  int idx_max_v[kBatchSize];
+#pragma unroll
+  for (int i = 0; i < kBatchSize; i++) {
+    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
+    idx_max_v[i] = idx_max / kVSize;
+  }
+
+  // read data from global memory
+  VecT src_reg[kBatchSize][kLoopsV];
+  VecT grad_reg[kBatchSize][kLoopsV];
+  VecT k_value;
+  for (int s = 0; s < kVSize; s++) {
+    reinterpret_cast<T*>(&k_value)[s] = 0.0;
+  }
+  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
+  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
+#pragma unroll
+  for (int i = 0; i < kBatchSize; ++i) {
+    int flag = i < local_batches ? 1 : 0;
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
+  }
+
+  // change T to AccT
+  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
+  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
+  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());

   // compute sum
   AccT sum[kBatchSize]{0.0};
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += static_cast<AccT>(gradptr[s]);
-        } else {
-          sum[i] += static_cast<AccT>(gradptr[s] * srcptr[s]);
-        }
-      }
-    }
-  }
+  AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
+  AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
+  kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
+      &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kps::details::ReduceMode::kLocalMode>(
+      &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

-  // write result
+  // write result to global memory
+  AccT out[kBatchSize][kLoopsV][kVSize];
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
     if (i >= local_batches) break;

+    AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
+    AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+        &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
+    kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
     VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
-
-    // max index to write
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      VecT tmpdata;
-      T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          tmpptr[s] = static_cast<AccT>(gradptr[s]) -
-                      std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
-        } else {
-          tmpptr[s] = static_cast<AccT>(srcptr[s]) *
-                      (static_cast<AccT>(gradptr[s]) - sum[i]);
-        }
-      }
-      int idx = threadIdx.x + it * kWarpSize;
-      if (idx < idx_max_v) {
-        dst_v[idx] = tmpdata;
-      }
-    }
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
   }
 }
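The arithmetic in the backward hunk is the standard softmax gradient: with forward output y and incoming gradient dy, it first reduces sum = sum_j dy_j * y_j (the MulFunctor + Reduce + WarpReduceSum chain), then forms dx_i = y_i * (dy_i - sum) via UnarySubFunctor and MulFunctor. A host-side reference of the non-LogMode branch, with illustrative names:

#include <cstddef>

// Reference softmax backward for one row:
//   dx_i = y_i * (dy_i - sum_j(dy_j * y_j))
void SoftmaxGradRowRef(const float* y, const float* dy, float* dx,
                       std::size_t n) {
  float sum = 0.0f;
  for (std::size_t j = 0; j < n; ++j) sum += dy[j] * y[j];
  for (std::size_t i = 0; i < n; ++i) dx[i] = y[i] * (dy[i] - sum);
}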
@@ -493,6 +473,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
   // vectorization read/write
   using T4 = typename VecT4<T>::Type;
   using T2 = typename VecT2<T>::Type;
   if (dim % 4 == 0) {
     SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx,
                                              out_data, x.data<T>(), N, dim,
...
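The driver above selects the widest vector type the row length allows: when `dim % 4 == 0`, each load moves four `T` values (e.g. a `float4`), which is where `kVSize = sizeof(VecT) / sizeof(T)` in the kernels comes from. A simplified sketch of that dispatch shape (hypothetical helper with stand-in traits; the real driver passes the full kernel argument list, and the `dim % 2` fallback is an assumption):

#include <cuda_runtime.h>

// Minimal stand-ins for the VecT4/VecT2 traits used by the driver.
template <typename T> struct VecT4;
template <> struct VecT4<float> { using Type = float4; };
template <typename T> struct VecT2;
template <> struct VecT2<float> { using Type = float2; };

// Hypothetical dispatch mirroring the `dim % 4 == 0` test above: pick
// VecT = T4, T2, or T by divisibility of the softmax axis.
template <typename T>
void DispatchByVecSize(int dim) {
  using T4 = typename VecT4<T>::Type;  // float4 when T = float
  using T2 = typename VecT2<T>::Type;  // float2 when T = float
  if (dim % 4 == 0) {
    // launch kernels instantiated with VecT = T4 (kVSize == 4)
  } else if (dim % 2 == 0) {
    // launch kernels instantiated with VecT = T2 (kVSize == 2)
  } else {
    // launch scalar kernels (VecT = T, kVSize == 1)
  }
}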