Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9b016c7c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9b016c7c
编写于
3月 01, 2021
作者:
Q
Qi Li
提交者:
GitHub
3月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] update fluid operators for rocm (part2), test=develop (#31211)
上级
2fd999d9
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
163 addition
and
26 deletion
+163
-26
paddle/fluid/operators/distributed_ops/CMakeLists.txt
paddle/fluid/operators/distributed_ops/CMakeLists.txt
+1
-1
paddle/fluid/operators/distributed_ops/allreduce_op.h
paddle/fluid/operators/distributed_ops/allreduce_op.h
+6
-2
paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
+6
-2
paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h
...le/fluid/operators/distributed_ops/ref_by_trainer_id_op.h
+1
-1
paddle/fluid/operators/kron_op.h
paddle/fluid/operators/kron_op.h
+16
-4
paddle/fluid/operators/matmul_v2_op.h
paddle/fluid/operators/matmul_v2_op.h
+7
-2
paddle/fluid/operators/prelu_op.cu
paddle/fluid/operators/prelu_op.cu
+7
-1
paddle/fluid/operators/reduce_ops/CMakeLists.txt
paddle/fluid/operators/reduce_ops/CMakeLists.txt
+5
-1
paddle/fluid/operators/reduce_ops/cub_reduce.h
paddle/fluid/operators/reduce_ops/cub_reduce.h
+47
-5
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
+6
-0
paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
+12
-0
paddle/fluid/operators/sequence_ops/sequence_mask_op.h
paddle/fluid/operators/sequence_ops/sequence_mask_op.h
+2
-2
paddle/fluid/operators/sequence_ops/sequence_reverse_op.h
paddle/fluid/operators/sequence_ops/sequence_reverse_op.h
+2
-2
paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc
...id/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc
+9
-0
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
+2
-2
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
+28
-1
paddle/fluid/operators/trace_op.cu
paddle/fluid/operators/trace_op.cu
+6
-0
未找到文件。
paddle/fluid/operators/distributed_ops/CMakeLists.txt
浏览文件 @
9b016c7c
...
@@ -30,7 +30,7 @@ endforeach()
...
@@ -30,7 +30,7 @@ endforeach()
register_operators
(
EXCLUDES gen_nccl_id_op DEPS
${
DISTRIBUTE_DEPS
}
)
register_operators
(
EXCLUDES gen_nccl_id_op DEPS
${
DISTRIBUTE_DEPS
}
)
if
(
WITH_NCCL
)
if
(
WITH_NCCL
OR WITH_RCCL
)
set
(
DISTRIBUTE_DEPS
${
DISTRIBUTE_DEPS
}
nccl_common
)
set
(
DISTRIBUTE_DEPS
${
DISTRIBUTE_DEPS
}
nccl_common
)
endif
()
endif
()
...
...
paddle/fluid/operators/distributed_ops/allreduce_op.h
浏览文件 @
9b016c7c
...
@@ -21,7 +21,7 @@ limitations under the License. */
...
@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#endif
...
@@ -36,7 +36,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
...
@@ -36,7 +36,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"AllReduce op can run on gpu place only for now."
));
"AllReduce op can run on gpu place only for now."
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
@@ -73,7 +73,11 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
...
@@ -73,7 +73,11 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
comm
,
stream
));
comm
,
stream
));
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
}
#else
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
...
...
paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
浏览文件 @
9b016c7c
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#endif
...
@@ -39,7 +39,7 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
...
@@ -39,7 +39,7 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"The place of ExecutionContext should be CUDAPlace."
));
"The place of ExecutionContext should be CUDAPlace."
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
int
dev_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
()).
device
;
int
dev_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
()).
device
;
int
root_dev_id
=
ctx
.
Attr
<
int
>
(
"root"
);
int
root_dev_id
=
ctx
.
Attr
<
int
>
(
"root"
);
...
@@ -68,7 +68,11 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
...
@@ -68,7 +68,11 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
<<
" From "
<<
root_dev_id
<<
" to "
<<
dev_id
;
<<
" From "
<<
root_dev_id
<<
" to "
<<
dev_id
;
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
}
#else
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
...
...
paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h
浏览文件 @
9b016c7c
...
@@ -30,7 +30,7 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
...
@@ -30,7 +30,7 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
int64_t
trainer_id
=
0
;
int64_t
trainer_id
=
0
;
auto
*
trainer_id_data
=
trainer_id_t
->
data
<
int64_t
>
();
auto
*
trainer_id_data
=
trainer_id_t
->
data
<
int64_t
>
();
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
()))
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
stream
=
context
.
cuda_device_context
().
stream
();
memory
::
Copy
<>
(
platform
::
CPUPlace
(),
&
trainer_id
,
memory
::
Copy
<>
(
platform
::
CPUPlace
(),
&
trainer_id
,
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()),
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()),
...
...
paddle/fluid/operators/kron_op.h
浏览文件 @
9b016c7c
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#if
__NVCC__
#if
defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "thrust/device_vector.h"
#include "thrust/device_vector.h"
#endif
#endif
...
@@ -87,7 +87,7 @@ struct KronOpFunctor {
...
@@ -87,7 +87,7 @@ struct KronOpFunctor {
const
int64_t
*
p_stride_x
=
nullptr
,
*
p_stride_y
=
nullptr
,
const
int64_t
*
p_stride_x
=
nullptr
,
*
p_stride_y
=
nullptr
,
*
p_stride_out
=
nullptr
,
*
p_shape_y
=
nullptr
;
*
p_stride_out
=
nullptr
,
*
p_shape_y
=
nullptr
;
#if
__NVCC__
#if
defined(__NVCC__) || defined(__HIPCC__)
thrust
::
device_vector
<
int64_t
>
d_stride_x
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_x
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_y
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_y
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_out
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_out
(
ndims
);
...
@@ -326,7 +326,7 @@ struct KronGradOpFunctor {
...
@@ -326,7 +326,7 @@ struct KronGradOpFunctor {
const
int64_t
*
p_stride_y
=
nullptr
;
const
int64_t
*
p_stride_y
=
nullptr
;
const
int64_t
*
p_stride_dout
=
nullptr
;
const
int64_t
*
p_stride_dout
=
nullptr
;
const
int64_t
*
p_shape_y
=
nullptr
;
const
int64_t
*
p_shape_y
=
nullptr
;
#if
__NVCC__
#if
defined(__NVCC__) || defined(__HIPCC__)
thrust
::
device_vector
<
int64_t
>
d_stride_x
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_x
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_y
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_y
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_dout
(
ndims
);
thrust
::
device_vector
<
int64_t
>
d_stride_dout
(
ndims
);
...
@@ -369,7 +369,19 @@ struct KronGradOpFunctor {
...
@@ -369,7 +369,19 @@ struct KronGradOpFunctor {
for_range
(
func
);
for_range
(
func
);
// reduce_sum along aixs 1
// reduce_sum along aixs 1
#if __NVCC__
#ifdef __HIPCC__
auto
stream
=
dev_ctx
.
stream
();
// it is a cuda device_context
if
(
dx
)
{
TensorReduce
<
T
,
T
,
hipcub
::
Sum
,
IdentityFunctor
<
T
>>
(
dout_x
,
dx
,
{
1
},
static_cast
<
T
>
(
0
),
hipcub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
}
if
(
dy
)
{
TensorReduce
<
T
,
T
,
hipcub
::
Sum
,
IdentityFunctor
<
T
>>
(
dout_y
,
dy
,
{
1
},
static_cast
<
T
>
(
0
),
hipcub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
}
#elif defined(__NVCC__)
auto
stream
=
dev_ctx
.
stream
();
// it is a cuda device_context
auto
stream
=
dev_ctx
.
stream
();
// it is a cuda device_context
if
(
dx
)
{
if
(
dx
)
{
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
...
...
paddle/fluid/operators/matmul_v2_op.h
浏览文件 @
9b016c7c
...
@@ -25,7 +25,7 @@ limitations under the License. */
...
@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#if
def __NVCC__
#if
defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif
#endif
...
@@ -45,7 +45,12 @@ template <typename DeviceContext, typename T>
...
@@ -45,7 +45,12 @@ template <typename DeviceContext, typename T>
void
ReduceSumForMatmulGrad
(
const
Tensor
*
input
,
Tensor
*
output
,
void
ReduceSumForMatmulGrad
(
const
Tensor
*
input
,
Tensor
*
output
,
const
std
::
vector
<
int
>&
reduce_dims
,
const
std
::
vector
<
int
>&
reduce_dims
,
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
#ifdef __NVCC__
#ifdef __HIPCC__
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduce
<
T
,
T
,
hipcub
::
Sum
,
IdentityFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
hipcub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
#elif defined(__NVCC__)
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
...
...
paddle/fluid/operators/prelu_op.cu
浏览文件 @
9b016c7c
...
@@ -95,7 +95,7 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
...
@@ -95,7 +95,7 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
template
<
typename
T
>
template
<
typename
T
>
class
PreluOpGradFunctor
{
class
PreluOpGradFunctor
{
public:
public:
void
operator
()(
cuda
Stream_t
stream
,
const
T
*
x
,
const
T
*
alpha
,
const
T
*
dy
,
void
operator
()(
gpu
Stream_t
stream
,
const
T
*
x
,
const
T
*
alpha
,
const
T
*
dy
,
T
*
dx
,
T
*
dalpha
,
const
framework
::
DDim
&
input_dims
,
T
*
dx
,
T
*
dalpha
,
const
framework
::
DDim
&
input_dims
,
PRELU_MODE
mode
)
{
PRELU_MODE
mode
)
{
size_t
numel
=
1
;
size_t
numel
=
1
;
...
@@ -174,9 +174,15 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
...
@@ -174,9 +174,15 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
reduce_dims
.
push_back
(
i
);
reduce_dims
.
push_back
(
i
);
}
}
#ifdef __HIPCC__
TensorReduce
<
T
,
T
,
hipcub
::
Sum
,
IdentityFunctor
<
T
>>
(
dalpha_tmp
,
dalpha
,
reduce_dims
,
static_cast
<
T
>
(
0
),
hipcub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
#else
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
dalpha_tmp
,
dalpha
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
dalpha_tmp
,
dalpha
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
IdentityFunctor
<
T
>
(),
stream
);
#endif
}
}
};
};
...
...
paddle/fluid/operators/reduce_ops/CMakeLists.txt
浏览文件 @
9b016c7c
...
@@ -13,7 +13,7 @@ else()
...
@@ -13,7 +13,7 @@ else()
register_operators
()
register_operators
()
endif
()
endif
()
if
(
WITH_GPU
)
if
(
WITH_GPU
OR WITH_ROCM
)
file
(
GLOB OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"*.part.cu"
)
file
(
GLOB OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"*.part.cu"
)
string
(
REPLACE
".part.cu"
""
OPS
"
${
OPS
}
"
)
string
(
REPLACE
".part.cu"
""
OPS
"
${
OPS
}
"
)
...
@@ -38,3 +38,7 @@ if(WITH_GPU)
...
@@ -38,3 +38,7 @@ if(WITH_GPU)
nv_test
(
check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor
)
nv_test
(
check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor
)
endif
()
endif
()
endif
()
endif
()
if
(
WITH_ROCM
)
hip_test
(
check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor
)
endif
()
paddle/fluid/operators/reduce_ops/cub_reduce.h
浏览文件 @
9b016c7c
...
@@ -20,7 +20,14 @@
...
@@ -20,7 +20,14 @@
#include <set>
#include <set>
#include <vector>
#include <vector>
#include <cub/cub.cuh> // NOLINT
#ifdef __NVCC__
#include "cub/cub.cuh" // NOLINT
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
...
@@ -64,7 +71,12 @@ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
...
@@ -64,7 +71,12 @@ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
__global__
void
ReduceKernel2D
(
const
Tx
*
x
,
Ty
*
y
,
ReduceOp
reducer
,
__global__
void
ReduceKernel2D
(
const
Tx
*
x
,
Ty
*
y
,
ReduceOp
reducer
,
TransformOp
transformer
,
Ty
init
,
TransformOp
transformer
,
Ty
init
,
int
reduce_num
)
{
int
reduce_num
)
{
#ifdef __HIPCC__
__shared__
typename
hipcub
::
BlockReduce
<
Ty
,
BlockDim
>::
TempStorage
temp_storage
;
#else
__shared__
typename
cub
::
BlockReduce
<
Ty
,
BlockDim
>::
TempStorage
temp_storage
;
__shared__
typename
cub
::
BlockReduce
<
Ty
,
BlockDim
>::
TempStorage
temp_storage
;
#endif
int
idx_x
=
blockIdx
.
x
*
reduce_num
;
int
idx_x
=
blockIdx
.
x
*
reduce_num
;
int
idx_y
=
threadIdx
.
x
;
int
idx_y
=
threadIdx
.
x
;
Ty
reduce_var
=
init
;
Ty
reduce_var
=
init
;
...
@@ -73,8 +85,13 @@ __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
...
@@ -73,8 +85,13 @@ __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
reducer
(
reduce_var
,
static_cast
<
Ty
>
(
transformer
(
x
[
idx_x
+
idx_y
])));
reducer
(
reduce_var
,
static_cast
<
Ty
>
(
transformer
(
x
[
idx_x
+
idx_y
])));
__syncthreads
();
__syncthreads
();
#ifdef __HIPCC__
reduce_var
=
hipcub
::
BlockReduce
<
Ty
,
BlockDim
>
(
temp_storage
)
.
Reduce
(
reduce_var
,
reducer
);
#else
reduce_var
=
reduce_var
=
cub
::
BlockReduce
<
Ty
,
BlockDim
>
(
temp_storage
).
Reduce
(
reduce_var
,
reducer
);
cub
::
BlockReduce
<
Ty
,
BlockDim
>
(
temp_storage
).
Reduce
(
reduce_var
,
reducer
);
#endif
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
y
[
blockIdx
.
x
]
=
reduce_var
;
y
[
blockIdx
.
x
]
=
reduce_var
;
...
@@ -90,7 +107,12 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
...
@@ -90,7 +107,12 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
Array
<
int
,
ReduceRank
>
reduce_strides
,
Array
<
int
,
ReduceRank
>
reduce_strides
,
Array
<
int
,
Rank
-
ReduceRank
>
left_dim
,
Array
<
int
,
Rank
-
ReduceRank
>
left_dim
,
Array
<
int
,
Rank
-
ReduceRank
>
left_strides
)
{
Array
<
int
,
Rank
-
ReduceRank
>
left_strides
)
{
#ifdef __HIPCC__
__shared__
typename
hipcub
::
BlockReduce
<
Ty
,
BlockDim
>::
TempStorage
temp_storage
;
#else
__shared__
typename
cub
::
BlockReduce
<
Ty
,
BlockDim
>::
TempStorage
temp_storage
;
__shared__
typename
cub
::
BlockReduce
<
Ty
,
BlockDim
>::
TempStorage
temp_storage
;
#endif
Array
<
int
,
Rank
>
sub_index
;
Array
<
int
,
Rank
>
sub_index
;
int
left_idx
=
blockIdx
.
x
;
int
left_idx
=
blockIdx
.
x
;
for
(
int
i
=
0
;
i
<
Rank
-
ReduceRank
;
++
i
)
{
for
(
int
i
=
0
;
i
<
Rank
-
ReduceRank
;
++
i
)
{
...
@@ -122,8 +144,13 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
...
@@ -122,8 +144,13 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
}
}
__syncthreads
();
__syncthreads
();
#ifdef __HIPCC__
reduce_var
=
hipcub
::
BlockReduce
<
Ty
,
BlockDim
>
(
temp_storage
)
.
Reduce
(
reduce_var
,
reducer
);
#else
reduce_var
=
reduce_var
=
cub
::
BlockReduce
<
Ty
,
BlockDim
>
(
temp_storage
).
Reduce
(
reduce_var
,
reducer
);
cub
::
BlockReduce
<
Ty
,
BlockDim
>
(
temp_storage
).
Reduce
(
reduce_var
,
reducer
);
#endif
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
y
[
blockIdx
.
x
]
=
reduce_var
;
y
[
blockIdx
.
x
]
=
reduce_var
;
...
@@ -188,7 +215,7 @@ static void TensorReduceImpl(
...
@@ -188,7 +215,7 @@ static void TensorReduceImpl(
int
left_num
,
int
reduce_num
,
const
std
::
vector
<
int
>&
x_strides
,
int
left_num
,
int
reduce_num
,
const
std
::
vector
<
int
>&
x_strides
,
const
std
::
vector
<
int
>&
reduce_dim
,
const
std
::
vector
<
int
>&
reduce_strides
,
const
std
::
vector
<
int
>&
reduce_dim
,
const
std
::
vector
<
int
>&
reduce_strides
,
const
std
::
vector
<
int
>&
left_dim
,
const
std
::
vector
<
int
>&
left_strides
,
const
std
::
vector
<
int
>&
left_dim
,
const
std
::
vector
<
int
>&
left_strides
,
cuda
Stream_t
stream
)
{
gpu
Stream_t
stream
)
{
#define CUB_RANK_CASE(i, ...) \
#define CUB_RANK_CASE(i, ...) \
case i: { \
case i: { \
constexpr auto kRank = i; \
constexpr auto kRank = i; \
...
@@ -211,17 +238,32 @@ static void TensorReduceImpl(
...
@@ -211,17 +238,32 @@ static void TensorReduceImpl(
int
rank
=
x_strides
.
size
();
int
rank
=
x_strides
.
size
();
int
reduce_rank
=
reduce_strides
.
size
();
int
reduce_rank
=
reduce_strides
.
size
();
if
(
rank
==
reduce_rank
)
{
if
(
rank
==
reduce_rank
)
{
#ifdef __HIPCC__
hipcub
::
TransformInputIterator
<
Ty
,
TransformOp
,
const
Tx
*>
trans_x
(
x_data
,
transformer
);
#else
cub
::
TransformInputIterator
<
Ty
,
TransformOp
,
const
Tx
*>
trans_x
(
cub
::
TransformInputIterator
<
Ty
,
TransformOp
,
const
Tx
*>
trans_x
(
x_data
,
transformer
);
x_data
,
transformer
);
#endif
size_t
temp_storage_bytes
=
0
;
size_t
temp_storage_bytes
=
0
;
#ifdef __HIPCC__
hipcub
::
DeviceReduce
::
Reduce
(
nullptr
,
temp_storage_bytes
,
trans_x
,
y_data
,
reduce_num
,
reducer
,
init
,
stream
);
#else
cub
::
DeviceReduce
::
Reduce
(
nullptr
,
temp_storage_bytes
,
trans_x
,
y_data
,
cub
::
DeviceReduce
::
Reduce
(
nullptr
,
temp_storage_bytes
,
trans_x
,
y_data
,
reduce_num
,
reducer
,
init
,
stream
);
reduce_num
,
reducer
,
init
,
stream
);
#endif
framework
::
Tensor
tmp
;
framework
::
Tensor
tmp
;
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
place
);
place
);
#ifdef __HIPCC__
hipcub
::
DeviceReduce
::
Reduce
(
temp_storage
,
temp_storage_bytes
,
trans_x
,
y_data
,
reduce_num
,
reducer
,
init
,
stream
);
#else
cub
::
DeviceReduce
::
Reduce
(
temp_storage
,
temp_storage_bytes
,
trans_x
,
y_data
,
cub
::
DeviceReduce
::
Reduce
(
temp_storage
,
temp_storage_bytes
,
trans_x
,
y_data
,
reduce_num
,
reducer
,
init
,
stream
);
reduce_num
,
reducer
,
init
,
stream
);
#endif
return
;
return
;
}
}
if
(
rank
==
2
&&
reduce_rank
==
1
&&
reduce_dim
[
0
]
==
1
)
{
if
(
rank
==
2
&&
reduce_rank
==
1
&&
reduce_dim
[
0
]
==
1
)
{
...
@@ -280,7 +322,7 @@ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
...
@@ -280,7 +322,7 @@ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
void
TensorReduce
(
const
framework
::
Tensor
&
x
,
framework
::
Tensor
*
y
,
void
TensorReduce
(
const
framework
::
Tensor
&
x
,
framework
::
Tensor
*
y
,
std
::
vector
<
int
>
origin_reduce_dims
,
const
Ty
&
init
,
std
::
vector
<
int
>
origin_reduce_dims
,
const
Ty
&
init
,
const
ReduceOp
&
reducer
,
const
TransformOp
&
transformer
,
const
ReduceOp
&
reducer
,
const
TransformOp
&
transformer
,
cuda
Stream_t
stream
)
{
gpu
Stream_t
stream
)
{
auto
x_dim
=
framework
::
vectorize
<
int
>
(
x
.
dims
());
auto
x_dim
=
framework
::
vectorize
<
int
>
(
x
.
dims
());
std
::
vector
<
int
>
new_x_dim
,
new_reduce_dims
;
std
::
vector
<
int
>
new_x_dim
,
new_reduce_dims
;
int
is_reduced
=
0
;
int
is_reduced
=
0
;
...
@@ -362,11 +404,11 @@ struct TensorReduceFunctor {
...
@@ -362,11 +404,11 @@ struct TensorReduceFunctor {
const
double
&
init
;
const
double
&
init
;
const
ReduceOp
&
reducer
;
const
ReduceOp
&
reducer
;
const
TransformOp
&
transformer
;
const
TransformOp
&
transformer
;
cuda
Stream_t
stream
;
gpu
Stream_t
stream
;
TensorReduceFunctor
(
const
framework
::
Tensor
&
x
,
framework
::
Tensor
*
y
,
TensorReduceFunctor
(
const
framework
::
Tensor
&
x
,
framework
::
Tensor
*
y
,
std
::
vector
<
int
>
origin_reduce_dims
,
const
double
&
init
,
std
::
vector
<
int
>
origin_reduce_dims
,
const
double
&
init
,
const
ReduceOp
&
reducer
,
const
TransformOp
&
transformer
,
const
ReduceOp
&
reducer
,
const
TransformOp
&
transformer
,
cuda
Stream_t
stream
)
gpu
Stream_t
stream
)
:
x
(
x
),
:
x
(
x
),
y
(
y
),
y
(
y
),
origin_reduce_dims
(
origin_reduce_dims
),
origin_reduce_dims
(
origin_reduce_dims
),
...
...
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
浏览文件 @
9b016c7c
...
@@ -56,9 +56,15 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
...
@@ -56,9 +56,15 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
}
}
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
stream
=
context
.
cuda_device_context
().
stream
();
#ifdef PADDLE_WITH_HIP
TensorReduce
<
T
,
T
,
hipcub
::
Sum
,
DivideFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
hipcub
::
Sum
(),
DivideFunctor
<
T
>
(
reduce_num
),
stream
);
#else
TensorReduce
<
T
,
T
,
cub
::
Sum
,
DivideFunctor
<
T
>>
(
TensorReduce
<
T
,
T
,
cub
::
Sum
,
DivideFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
DivideFunctor
<
T
>
(
reduce_num
),
stream
);
DivideFunctor
<
T
>
(
reduce_num
),
stream
);
#endif
}
}
};
};
...
...
paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
浏览文件 @
9b016c7c
...
@@ -56,13 +56,25 @@ class ReduceSumKernel : public framework::OpKernel<T> {
...
@@ -56,13 +56,25 @@ class ReduceSumKernel : public framework::OpKernel<T> {
if
(
out_dtype
>=
0
)
{
if
(
out_dtype
>=
0
)
{
framework
::
VisitDataTypeSmall
(
framework
::
VisitDataTypeSmall
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
out_dtype
),
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
out_dtype
),
#ifdef __HIPCC__
TensorReduceFunctor
<
T
,
hipcub
::
Sum
,
IdentityFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
double
>
(
0.0
),
hipcub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
));
#else
TensorReduceFunctor
<
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
TensorReduceFunctor
<
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
double
>
(
0.0
),
cub
::
Sum
(),
*
input
,
output
,
reduce_dims
,
static_cast
<
double
>
(
0.0
),
cub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
));
IdentityFunctor
<
T
>
(),
stream
));
#endif
}
else
{
}
else
{
#ifdef __HIPCC__
TensorReduce
<
T
,
T
,
hipcub
::
Sum
,
IdentityFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
hipcub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
#else
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
*
input
,
output
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
IdentityFunctor
<
T
>
(),
stream
);
#endif
}
}
}
}
};
};
...
...
paddle/fluid/operators/sequence_ops/sequence_mask_op.h
浏览文件 @
9b016c7c
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#pragma once
#pragma once
#if
def __NVCC__
#if
defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/device_ptr.h>
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#include <thrust/reduce.h>
...
@@ -107,7 +107,7 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> {
...
@@ -107,7 +107,7 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> {
auto
*
x_data
=
x
->
data
<
Tx
>
();
auto
*
x_data
=
x
->
data
<
Tx
>
();
auto
x_numel
=
x
->
numel
();
auto
x_numel
=
x
->
numel
();
if
(
maxlen
<
0
)
{
if
(
maxlen
<
0
)
{
#if
def __NVCC__
#if
defined(__NVCC__) || defined(__HIPCC__)
VLOG
(
10
)
VLOG
(
10
)
<<
"SequenceMaskOp on GPU may be slow when maxlen is not provided."
;
<<
"SequenceMaskOp on GPU may be slow when maxlen is not provided."
;
maxlen
=
static_cast
<
int
>
(
maxlen
=
static_cast
<
int
>
(
...
...
paddle/fluid/operators/sequence_ops/sequence_reverse_op.h
浏览文件 @
9b016c7c
...
@@ -130,13 +130,13 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
...
@@ -130,13 +130,13 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
const
size_t
*
lod
;
const
size_t
*
lod
;
size_t
lod_count
=
x
.
lod
()[
0
].
size
();
size_t
lod_count
=
x
.
lod
()[
0
].
size
();
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
lod
=
x
.
lod
()[
0
].
CUDAData
(
ctx
.
GetPlace
());
lod
=
x
.
lod
()[
0
].
CUDAData
(
ctx
.
GetPlace
());
}
else
{
}
else
{
#endif
#endif
lod
=
x
.
lod
()[
0
].
data
();
lod
=
x
.
lod
()[
0
].
data
();
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
}
}
#endif
#endif
...
...
paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc
浏览文件 @
9b016c7c
...
@@ -104,9 +104,18 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
...
@@ -104,9 +104,18 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
#ifdef PADDLE_WITH_HIP
// MIOPEN not support float64
REGISTER_OP_KERNEL
(
sequence_softmax
,
CUDNN
,
::
paddle
::
platform
::
CUDAPlace
,
ops
::
SequenceSoftmaxCUDNNKernel
<
float
>
);
REGISTER_OP_KERNEL
(
sequence_softmax_grad
,
CUDNN
,
::
paddle
::
platform
::
CUDAPlace
,
ops
::
SequenceSoftmaxGradCUDNNKernel
<
float
>
);
#else
REGISTER_OP_KERNEL
(
sequence_softmax
,
CUDNN
,
::
paddle
::
platform
::
CUDAPlace
,
REGISTER_OP_KERNEL
(
sequence_softmax
,
CUDNN
,
::
paddle
::
platform
::
CUDAPlace
,
ops
::
SequenceSoftmaxCUDNNKernel
<
float
>
,
ops
::
SequenceSoftmaxCUDNNKernel
<
float
>
,
ops
::
SequenceSoftmaxCUDNNKernel
<
double
>
);
ops
::
SequenceSoftmaxCUDNNKernel
<
double
>
);
REGISTER_OP_KERNEL
(
sequence_softmax_grad
,
CUDNN
,
::
paddle
::
platform
::
CUDAPlace
,
REGISTER_OP_KERNEL
(
sequence_softmax_grad
,
CUDNN
,
::
paddle
::
platform
::
CUDAPlace
,
ops
::
SequenceSoftmaxGradCUDNNKernel
<
float
>
,
ops
::
SequenceSoftmaxGradCUDNNKernel
<
float
>
,
ops
::
SequenceSoftmaxGradCUDNNKernel
<
double
>
);
ops
::
SequenceSoftmaxGradCUDNNKernel
<
double
>
);
#endif
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
浏览文件 @
9b016c7c
...
@@ -36,7 +36,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
...
@@ -36,7 +36,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
// choose cudnn kernel if the runtime supported.
// choose cudnn kernel if the runtime supported.
bool
use_cudnn
=
ctx
.
Attr
<
bool
>
(
"use_cudnn"
);
bool
use_cudnn
=
ctx
.
Attr
<
bool
>
(
"use_cudnn"
);
bool
runtime_cudnn_support
=
false
;
bool
runtime_cudnn_support
=
false
;
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
auto
&
dev_ctx
=
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
...
@@ -132,7 +132,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
...
@@ -132,7 +132,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
// choose cudnn kernel if the runtime supported.
// choose cudnn kernel if the runtime supported.
bool
use_cudnn
=
ctx
.
Attr
<
bool
>
(
"use_cudnn"
);
bool
use_cudnn
=
ctx
.
Attr
<
bool
>
(
"use_cudnn"
);
bool
runtime_cudnn_support
=
false
;
bool
runtime_cudnn_support
=
false
;
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
auto
&
dev_ctx
=
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
...
...
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
浏览文件 @
9b016c7c
...
@@ -13,7 +13,15 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <cub/cub.cuh> // NOLINT
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
...
@@ -23,7 +31,11 @@ namespace operators {
...
@@ -23,7 +31,11 @@ namespace operators {
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
,
int
BlockDim
>
template
<
typename
T
,
int
BlockDim
>
#ifdef __HIPCC__
using
BlockReduce
=
hipcub
::
BlockReduce
<
T
,
BlockDim
>
;
#else
using
BlockReduce
=
cub
::
BlockReduce
<
T
,
BlockDim
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
T
,
BlockDim
>
;
#endif
template
<
typename
T
,
int
BlockDim
>
template
<
typename
T
,
int
BlockDim
>
using
BlockReduceTempStorage
=
typename
BlockReduce
<
T
,
BlockDim
>::
TempStorage
;
using
BlockReduceTempStorage
=
typename
BlockReduce
<
T
,
BlockDim
>::
TempStorage
;
...
@@ -45,8 +57,13 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
...
@@ -45,8 +57,13 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
T
ele
=
in_data
[
start
+
tid
];
T
ele
=
in_data
[
start
+
tid
];
max_ele
=
max_ele
>
ele
?
max_ele
:
ele
;
max_ele
=
max_ele
>
ele
?
max_ele
:
ele
;
}
}
#ifdef __HIPCC__
max_ele
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
max_ele
,
hipcub
::
Max
());
#else
max_ele
=
max_ele
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
max_ele
,
cub
::
Max
());
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
max_ele
,
cub
::
Max
());
#endif
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
shared_max_data
=
max_ele
;
shared_max_data
=
max_ele
;
}
}
...
@@ -58,8 +75,13 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
...
@@ -58,8 +75,13 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
T
ele
=
in_data
[
start
+
tid
];
T
ele
=
in_data
[
start
+
tid
];
sum_data
+=
real_exp
(
ele
-
shared_max_data
);
sum_data
+=
real_exp
(
ele
-
shared_max_data
);
}
}
#ifdef __HIPCC__
sum_data
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
sum_data
,
hipcub
::
Sum
());
#else
sum_data
=
sum_data
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
sum_data
,
cub
::
Sum
());
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
sum_data
,
cub
::
Sum
());
#endif
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
shared_sum_data
=
sum_data
;
shared_sum_data
=
sum_data
;
}
}
...
@@ -94,7 +116,12 @@ __global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data,
...
@@ -94,7 +116,12 @@ __global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data,
T
s_d
=
softmax_data
[
idx
];
T
s_d
=
softmax_data
[
idx
];
result
+=
s_g_d
*
s_d
;
result
+=
s_g_d
*
s_d
;
}
}
#ifdef __HIPCC__
result
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
result
,
hipcub
::
Sum
());
#else
result
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
result
,
cub
::
Sum
());
result
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
result
,
cub
::
Sum
());
#endif
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
shared_data
=
result
;
shared_data
=
result
;
}
}
...
...
paddle/fluid/operators/trace_op.cu
浏览文件 @
9b016c7c
...
@@ -43,9 +43,15 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
...
@@ -43,9 +43,15 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
stream
=
context
.
cuda_device_context
().
stream
();
std
::
vector
<
int
>
reduce_dims
;
std
::
vector
<
int
>
reduce_dims
;
reduce_dims
.
push_back
(
out
->
dims
().
size
());
reduce_dims
.
push_back
(
out
->
dims
().
size
());
#ifdef __HIPCC__
TensorReduce
<
T
,
T
,
hipcub
::
Sum
,
IdentityFunctor
<
T
>>
(
diag
,
out
,
reduce_dims
,
static_cast
<
T
>
(
0
),
hipcub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
#else
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>>
(
diag
,
out
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
diag
,
out
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
IdentityFunctor
<
T
>
(),
stream
);
#endif
}
}
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录