Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
59940cb3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
59940cb3
编写于
3月 02, 2021
作者:
Q
Qi Li
提交者:
GitHub
3月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] update fluid operators for rocm (part8), test=develop (#31309)
上级
5d7a8b05
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
687 addition
and
28 deletion
+687
-28
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
+5
-0
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/grid_sampler_op.cc
+5
-2
paddle/fluid/operators/group_norm_op.cu
paddle/fluid/operators/group_norm_op.cu
+19
-4
paddle/fluid/operators/index_select_op.cu
paddle/fluid/operators/index_select_op.cu
+16
-0
paddle/fluid/operators/inplace_abn_op.cu
paddle/fluid/operators/inplace_abn_op.cu
+9
-0
paddle/fluid/operators/instance_norm_op.cu
paddle/fluid/operators/instance_norm_op.cu
+117
-1
paddle/fluid/operators/layer_norm_op.cu
paddle/fluid/operators/layer_norm_op.cu
+38
-3
paddle/fluid/operators/layer_norm_op.h
paddle/fluid/operators/layer_norm_op.h
+4
-4
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+1
-1
paddle/fluid/operators/matmul_op.cc
paddle/fluid/operators/matmul_op.cc
+13
-7
paddle/fluid/operators/mean_op.cu
paddle/fluid/operators/mean_op.cu
+6
-0
paddle/fluid/operators/merge_lod_tensor_op.cc
paddle/fluid/operators/merge_lod_tensor_op.cc
+1
-1
paddle/fluid/operators/miopen_lstm_cache.h
paddle/fluid/operators/miopen_lstm_cache.h
+141
-0
paddle/fluid/operators/miopen_rnn_cache.h
paddle/fluid/operators/miopen_rnn_cache.h
+267
-0
paddle/fluid/operators/modified_huber_loss_op.h
paddle/fluid/operators/modified_huber_loss_op.h
+2
-2
paddle/fluid/operators/multinomial_op.cu
paddle/fluid/operators/multinomial_op.cu
+19
-2
paddle/fluid/operators/nll_loss_op.cu
paddle/fluid/operators/nll_loss_op.cu
+8
-1
paddle/fluid/operators/norm_op.cu
paddle/fluid/operators/norm_op.cu
+6
-0
paddle/fluid/operators/norm_utils.cu.h
paddle/fluid/operators/norm_utils.cu.h
+10
-0
未找到文件。
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
浏览文件 @
59940cb3
...
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP not support cudnnSpatialTfGridGeneratorForward
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
...
@@ -140,3 +143,5 @@ REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
...
@@ -140,3 +143,5 @@ REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
REGISTER_OP_KERNEL
(
grid_sampler_grad
,
CUDNN
,
plat
::
CUDAPlace
,
REGISTER_OP_KERNEL
(
grid_sampler_grad
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
double
>
);
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
double
>
);
#endif // PADDLE_WITH_HIP
paddle/fluid/operators/grid_sampler_op.cc
浏览文件 @
59940cb3
...
@@ -20,6 +20,9 @@ limitations under the License. */
...
@@ -20,6 +20,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -71,7 +74,7 @@ class GridSampleOp : public framework::OperatorWithKernel {
...
@@ -71,7 +74,7 @@ class GridSampleOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
...
@@ -191,7 +194,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
...
@@ -191,7 +194,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
...
...
paddle/fluid/operators/group_norm_op.cu
浏览文件 @
59940cb3
...
@@ -12,9 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -39,10 +46,18 @@ enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
...
@@ -39,10 +46,18 @@ enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
template
<
typename
T
>
template
<
typename
T
>
__device__
__inline__
void
CudaAtomicAddWithWarp
(
T
*
sum
,
T
value
)
{
__device__
__inline__
void
CudaAtomicAddWithWarp
(
T
*
sum
,
T
value
)
{
#ifdef PADDLE_WITH_CUDA
typedef
cub
::
WarpReduce
<
T
>
WarpReduce
;
typedef
cub
::
WarpReduce
<
T
>
WarpReduce
;
#else
typedef
hipcub
::
WarpReduce
<
T
>
WarpReduce
;
#endif
typename
WarpReduce
::
TempStorage
temp_storage
;
typename
WarpReduce
::
TempStorage
temp_storage
;
value
=
WarpReduce
(
temp_storage
).
Sum
(
value
);
value
=
WarpReduce
(
temp_storage
).
Sum
(
value
);
#ifdef PADDLE_WITH_CUDA
if
(
cub
::
LaneId
()
==
0
)
platform
::
CudaAtomicAdd
(
sum
,
value
);
if
(
cub
::
LaneId
()
==
0
)
platform
::
CudaAtomicAdd
(
sum
,
value
);
#else
if
(
hipcub
::
LaneId
()
==
0
)
platform
::
CudaAtomicAdd
(
sum
,
value
);
#endif
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -217,10 +232,10 @@ __global__ void GroupNormBackwardGetMeanAndVar(
...
@@ -217,10 +232,10 @@ __global__ void GroupNormBackwardGetMeanAndVar(
d_bias_data
+=
dval
;
d_bias_data
+=
dval
;
d_scale_data
+=
val
*
dval
;
d_scale_data
+=
val
*
dval
;
}
}
CudaAtomicAddWithWarp
(
&
d_mean
[
bid
*
groups
+
gid
]
,
d_mean_data
);
CudaAtomicAddWithWarp
(
&
(
d_mean
[
bid
*
groups
+
gid
])
,
d_mean_data
);
CudaAtomicAddWithWarp
(
&
d_var
[
bid
*
groups
+
gid
]
,
d_var_data
);
CudaAtomicAddWithWarp
(
&
(
d_var
[
bid
*
groups
+
gid
])
,
d_var_data
);
if
(
flags
&
kHasScale
)
CudaAtomicAddWithWarp
(
&
d_scale
[
ccid
]
,
d_scale_data
);
if
(
flags
&
kHasScale
)
CudaAtomicAddWithWarp
(
&
(
d_scale
[
ccid
])
,
d_scale_data
);
if
(
flags
&
kHasBias
)
CudaAtomicAddWithWarp
(
&
d_bias
[
ccid
]
,
d_bias_data
);
if
(
flags
&
kHasBias
)
CudaAtomicAddWithWarp
(
&
(
d_bias
[
ccid
])
,
d_bias_data
);
}
}
template
<
typename
T
,
int
flags
>
template
<
typename
T
,
int
flags
>
...
...
paddle/fluid/operators/index_select_op.cu
浏览文件 @
59940cb3
...
@@ -106,14 +106,22 @@ class IndexSelectCUDAKernel : public framework::OpKernel<T> {
...
@@ -106,14 +106,22 @@ class IndexSelectCUDAKernel : public framework::OpKernel<T> {
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
out_data
,
index_data
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
out_data
,
index_data
,
numel
,
stride
,
size
,
delta
);
numel
,
stride
,
size
,
delta
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
else
{
}
else
{
const
int
*
index_data
=
index
->
data
<
int
>
();
const
int
*
index_data
=
index
->
data
<
int
>
();
index_select_cuda_kernel
<
T
,
int
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
index_select_cuda_kernel
<
T
,
int
><<<
(
numel
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
out_data
,
index_data
,
numel
,
stride
,
size
,
delta
);
in_data
,
out_data
,
index_data
,
numel
,
stride
,
size
,
delta
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
}
}
}
};
};
...
@@ -164,7 +172,11 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -164,7 +172,11 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
index_data
,
index_nums
,
numel
,
index_data
,
index_nums
,
numel
,
stride
,
size
,
delta
);
stride
,
size
,
delta
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
else
{
}
else
{
const
int
*
index_data
=
index
->
data
<
int
>
();
const
int
*
index_data
=
index
->
data
<
int
>
();
index_select_grad_cuda_kernel
<
T
,
int
><<<
index_select_grad_cuda_kernel
<
T
,
int
><<<
...
@@ -172,7 +184,11 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -172,7 +184,11 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
output_grad_data
,
in_grad_data
,
index_data
,
index_nums
,
numel
,
index_data
,
index_nums
,
numel
,
stride
,
size
,
delta
);
stride
,
size
,
delta
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
}
}
}
};
};
...
...
paddle/fluid/operators/inplace_abn_op.cu
浏览文件 @
59940cb3
...
@@ -84,9 +84,18 @@ class InplaceABNGradKernel
...
@@ -84,9 +84,18 @@ class InplaceABNGradKernel
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
REGISTER_OP_CUDA_KERNEL
(
inplace_abn
,
ops
::
InplaceABNKernel
<
plat
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
inplace_abn_grad
,
ops
::
InplaceABNGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
inplace_abn
,
REGISTER_OP_CUDA_KERNEL
(
inplace_abn
,
ops
::
InplaceABNKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
InplaceABNKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
InplaceABNKernel
<
plat
::
CUDADeviceContext
,
double
>
);
ops
::
InplaceABNKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
inplace_abn_grad
,
ops
::
InplaceABNGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
inplace_abn_grad
,
ops
::
InplaceABNGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
InplaceABNGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
ops
::
InplaceABNGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
#endif
paddle/fluid/operators/instance_norm_op.cu
浏览文件 @
59940cb3
...
@@ -16,11 +16,22 @@ limitations under the License. */
...
@@ -16,11 +16,22 @@ limitations under the License. */
#include <cfloat>
#include <cfloat>
#include <string>
#include <string>
#include <vector>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/instance_norm_op.h"
#include "paddle/fluid/operators/instance_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -99,6 +110,15 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
...
@@ -99,6 +110,15 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
data_desc_
;
miopenTensorDescriptor_t
in_param_desc_
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
in_param_desc_
));
#else
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
in_param_desc_
;
cudnnTensorDescriptor_t
in_param_desc_
;
...
@@ -106,7 +126,7 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
...
@@ -106,7 +126,7 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
in_param_desc_
));
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
in_param_desc_
));
#endif
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
<<
"CUDNN_BN_MIN_EPSILON. Setting it to "
<<
"CUDNN_BN_MIN_EPSILON. Setting it to "
...
@@ -122,12 +142,22 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
...
@@ -122,12 +142,22 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
const_cast
<
int
*>
(
dims
.
data
()),
const_cast
<
int
*>
(
strides
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDeriveBNTensorDescriptor
(
in_param_desc_
,
data_desc_
,
miopenBNSpatial
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
in_param_desc_
,
data_desc_
,
CUDNN_BATCHNORM_SPATIAL
));
in_param_desc_
,
data_desc_
,
CUDNN_BATCHNORM_SPATIAL
));
#endif
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
...
@@ -171,6 +201,35 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
...
@@ -171,6 +201,35 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
functor
(
dev_ctx
,
saved_mean
,
static_cast
<
BatchNormParamType
<
T
>>
(
0
));
functor
(
dev_ctx
,
saved_mean
,
static_cast
<
BatchNormParamType
<
T
>>
(
0
));
functor
(
dev_ctx
,
saved_variance
,
static_cast
<
BatchNormParamType
<
T
>>
(
0
));
functor
(
dev_ctx
,
saved_variance
,
static_cast
<
BatchNormParamType
<
T
>>
(
0
));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenBatchNormalizationForwardTraining
(
handle
,
miopenBNSpatial
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kOne
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kZero
())),
data_desc_
,
static_cast
<
const
void
*>
(
x_tmp
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
y
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
())),
in_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale_tmp
.
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias_tmp
.
template
data
<
BatchNormParamType
<
T
>
>
())),
0
,
nullptr
,
nullptr
,
epsilon
,
static_cast
<
void
*>
(
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()))));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
in_param_desc_
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnBatchNormalizationForwardTraining
(
platform
::
dynload
::
cudnnBatchNormalizationForwardTraining
(
handle
,
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
handle
,
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
...
@@ -188,6 +247,7 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
...
@@ -188,6 +247,7 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
in_param_desc_
));
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
in_param_desc_
));
#endif
}
}
};
};
...
@@ -332,6 +392,15 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -332,6 +392,15 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
return
;
return
;
}
}
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
data_desc_
;
miopenTensorDescriptor_t
in_param_desc_
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
in_param_desc_
));
#else
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
in_param_desc_
;
cudnnTensorDescriptor_t
in_param_desc_
;
...
@@ -339,6 +408,8 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -339,6 +408,8 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
in_param_desc_
));
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
in_param_desc_
));
#endif
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
<<
"CUDNN_BN_MIN_EPSILON. Setting it to "
<<
"CUDNN_BN_MIN_EPSILON. Setting it to "
...
@@ -346,12 +417,22 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -346,12 +417,22 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
}
}
epsilon
=
std
::
max
(
epsilon
,
CUDNN_BN_MIN_EPSILON
);
epsilon
=
std
::
max
(
epsilon
,
CUDNN_BN_MIN_EPSILON
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
const_cast
<
int
*>
(
dims
.
data
()),
const_cast
<
int
*>
(
strides
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDeriveBNTensorDescriptor
(
in_param_desc_
,
data_desc_
,
miopenBNSpatial
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
in_param_desc_
,
data_desc_
,
CUDNN_BATCHNORM_SPATIAL
));
in_param_desc_
,
data_desc_
,
CUDNN_BATCHNORM_SPATIAL
));
#endif
const
auto
*
saved_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
saved_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
saved_var
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
saved_var
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
...
@@ -360,6 +441,21 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -360,6 +441,21 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
const
auto
*
saved_var_data
=
const
auto
*
saved_var_data
=
saved_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
saved_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
if
(
d_scale
&&
d_bias
)
{
if
(
d_scale
&&
d_bias
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
miopenBNSpatial
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x_tmp
.
template
data
<
T
>(),
data_desc_
,
d_y_tmp
.
template
data
<
T
>(),
data_desc_
,
d_x
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
in_param_desc_
,
scale_tmp
.
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
CUDNN_BATCHNORM_SPATIAL
,
dev_ctx
.
cudnn_handle
(),
CUDNN_BATCHNORM_SPATIAL
,
...
@@ -373,6 +469,7 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -373,6 +469,7 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
d_bias_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
d_bias_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
epsilon
,
saved_mean_data
,
saved_var_data
));
#endif
}
else
{
}
else
{
if
(
d_x
)
{
if
(
d_x
)
{
GradComputeDX
<
T
,
block
><<<
NxC
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
GradComputeDX
<
T
,
block
><<<
NxC
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
...
@@ -389,10 +486,17 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -389,10 +486,17 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
d_bias_tmp
.
data
<
T
>
(),
d_bias
->
data
<
T
>
(),
N
,
C
);
d_bias_tmp
.
data
<
T
>
(),
d_bias
->
data
<
T
>
(),
N
,
C
);
}
}
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
in_param_desc_
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
in_param_desc_
));
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
in_param_desc_
));
#endif
}
}
};
};
...
@@ -693,6 +797,17 @@ class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
...
@@ -693,6 +797,17 @@ class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
REGISTER_OP_CUDA_KERNEL
(
instance_norm
,
ops
::
InstanceNormKernel
<
plat
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
instance_norm_grad
,
ops
::
InstanceNormGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
instance_norm_grad_grad
,
ops
::
InstanceNormDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
instance_norm
,
ops
::
InstanceNormKernel
<
plat
::
CUDADeviceContext
,
float
>
,
instance_norm
,
ops
::
InstanceNormKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
InstanceNormKernel
<
plat
::
CUDADeviceContext
,
double
>
);
ops
::
InstanceNormKernel
<
plat
::
CUDADeviceContext
,
double
>
);
...
@@ -706,3 +821,4 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -706,3 +821,4 @@ REGISTER_OP_CUDA_KERNEL(
float
>
,
float
>
,
ops
::
InstanceNormDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
InstanceNormDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
double
>
);
#endif
paddle/fluid/operators/layer_norm_op.cu
浏览文件 @
59940cb3
...
@@ -12,14 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,14 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <cub/cub.cuh>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -348,7 +359,11 @@ __global__ void LayerNormBackwardComputeGradInput(
...
@@ -348,7 +359,11 @@ __global__ void LayerNormBackwardComputeGradInput(
// epsilon, const T* gamma,
// epsilon, const T* gamma,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
var
,
const
float
epsilon
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
var
,
const
float
epsilon
,
const
U
*
gamma
,
T
*
grad_input
)
{
const
U
*
gamma
,
T
*
grad_input
)
{
#ifdef __HIPCC__
for
(
auto
i1
=
hipBlockIdx_y
;
i1
<
n1
;
i1
+=
hipGridDim_y
)
{
#else
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
#endif
U
sum_loss1
=
U
(
0
);
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
const
U
c_mean
=
mean
[
i1
];
...
@@ -392,12 +407,19 @@ __global__ void LayerNormBackwardComputeGradInput(
...
@@ -392,12 +407,19 @@ __global__ void LayerNormBackwardComputeGradInput(
}
}
// intra-warp reductions
// intra-warp reductions
for
(
int
mask
=
BDIMX
/
2
;
mask
>
0
;
mask
/=
2
)
{
for
(
int
mask
=
BDIMX
/
2
;
mask
>
0
;
mask
/=
2
)
{
#ifdef PADDLE_WITH_HIP
sum_loss1
+=
__shfl_xor
(
sum_loss1
,
mask
,
warpSize
);
// WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2
+=
__shfl_xor
(
sum_loss2
,
mask
,
warpSize
);
// WARP_SHFL_XOR(sum_loss2, mask);
#else
sum_loss1
+=
sum_loss1
+=
__shfl_xor_sync
(
0xffffffff
,
sum_loss1
,
mask
,
__shfl_xor_sync
(
0xffffffff
,
sum_loss1
,
mask
,
warpSize
);
// WARP_SHFL_XOR(sum_loss1, mask);
warpSize
);
// WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2
+=
sum_loss2
+=
__shfl_xor_sync
(
0xffffffff
,
sum_loss2
,
mask
,
__shfl_xor_sync
(
0xffffffff
,
sum_loss2
,
mask
,
warpSize
);
// WARP_SHFL_XOR(sum_loss2, mask);
warpSize
);
// WARP_SHFL_XOR(sum_loss2, mask);
#endif
}
}
// inter-warp reductions
// inter-warp reductions
if
(
BDIMY
>
1
)
{
if
(
BDIMY
>
1
)
{
...
@@ -821,7 +843,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
...
@@ -821,7 +843,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
LayerNormDirectCUDAFunctor
<
T
>::
operator
()(
cuda
Stream_t
stream
,
void
LayerNormDirectCUDAFunctor
<
T
>::
operator
()(
gpu
Stream_t
stream
,
const
T
*
input
,
const
T
*
input
,
std
::
vector
<
int
>
input_shape
,
std
::
vector
<
int
>
input_shape
,
const
T
*
bias
,
const
T
*
scale
,
const
T
*
bias
,
const
T
*
scale
,
...
@@ -942,6 +964,18 @@ template class LayerNormDirectCUDAFunctor<float>;
...
@@ -942,6 +964,18 @@ template class LayerNormDirectCUDAFunctor<float>;
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
REGISTER_OP_CUDA_KERNEL
(
layer_norm
,
ops
::
LayerNormKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
LayerNormKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
layer_norm_grad
,
ops
::
LayerNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
LayerNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
layer_norm
,
layer_norm
,
ops
::
LayerNormKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
LayerNormKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
@@ -953,3 +987,4 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -953,3 +987,4 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
LayerNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
LayerNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
LayerNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
LayerNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
plat
::
float16
>
);
#endif
paddle/fluid/operators/layer_norm_op.h
浏览文件 @
59940cb3
...
@@ -51,7 +51,7 @@ struct RowwiseMean2D {
...
@@ -51,7 +51,7 @@ struct RowwiseMean2D {
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
vec
);
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
vec
);
};
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template
<
typename
T
>
template
<
typename
T
>
class
RowwiseMean2D
<
platform
::
CUDADeviceContext
,
T
>
{
class
RowwiseMean2D
<
platform
::
CUDADeviceContext
,
T
>
{
public:
public:
...
@@ -97,7 +97,7 @@ struct ColwiseSum2D {
...
@@ -97,7 +97,7 @@ struct ColwiseSum2D {
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
vec
);
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
vec
);
};
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template
<
typename
T
>
template
<
typename
T
>
class
ColwiseSum2D
<
platform
::
CUDADeviceContext
,
T
>
{
class
ColwiseSum2D
<
platform
::
CUDADeviceContext
,
T
>
{
public:
public:
...
@@ -163,11 +163,11 @@ using Tensor = framework::Tensor;
...
@@ -163,11 +163,11 @@ using Tensor = framework::Tensor;
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
DataLayout
=
framework
::
DataLayout
;
using
DataLayout
=
framework
::
DataLayout
;
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template
<
typename
T
>
template
<
typename
T
>
class
LayerNormDirectCUDAFunctor
{
class
LayerNormDirectCUDAFunctor
{
public:
public:
void
operator
()(
cuda
Stream_t
stream
,
const
T
*
input
,
void
operator
()(
gpu
Stream_t
stream
,
const
T
*
input
,
std
::
vector
<
int
>
input_shape
,
const
T
*
bias
,
const
T
*
scale
,
std
::
vector
<
int
>
input_shape
,
const
T
*
bias
,
const
T
*
scale
,
T
*
output
,
T
*
mean
,
T
*
variance
,
int
begin_norm_axis
,
T
*
output
,
T
*
mean
,
T
*
variance
,
int
begin_norm_axis
,
float
eps
);
float
eps
);
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
59940cb3
...
@@ -63,7 +63,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
...
@@ -63,7 +63,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
if
(
std
::
is_same
<
Place
,
platform
::
CPUPlace
>::
value
)
{
if
(
std
::
is_same
<
Place
,
platform
::
CPUPlace
>::
value
)
{
Apply
(
static_cast
<
platform
::
CPUDeviceContext
*>
(
dev_ctx
));
Apply
(
static_cast
<
platform
::
CPUDeviceContext
*>
(
dev_ctx
));
}
else
{
}
else
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Apply
(
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
));
Apply
(
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
));
#else
#else
PADDLE_THROW
(
PADDLE_THROW
(
...
...
paddle/fluid/operators/matmul_op.cc
浏览文件 @
59940cb3
...
@@ -76,7 +76,8 @@ class MatMulKernel : public framework::OpKernel<T> {
...
@@ -76,7 +76,8 @@ class MatMulKernel : public framework::OpKernel<T> {
auto
scale
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"alpha"
));
auto
scale
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"alpha"
));
int
head_number
=
1
;
int
head_number
=
1
;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
#endif
#endif
...
@@ -89,7 +90,8 @@ class MatMulKernel : public framework::OpKernel<T> {
...
@@ -89,7 +90,8 @@ class MatMulKernel : public framework::OpKernel<T> {
mat_dim_a
.
batch_size_
=
0
;
mat_dim_a
.
batch_size_
=
0
;
}
}
}
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
bool
split_vertical_y
=
(
mat_dim_a
.
width_
!=
mat_dim_b
.
height_
);
bool
split_vertical_y
=
(
mat_dim_a
.
width_
!=
mat_dim_b
.
height_
);
if
(
head_number
>
1
)
{
if
(
head_number
>
1
)
{
...
@@ -228,7 +230,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
...
@@ -228,7 +230,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
b
.
dims
(),
0
,
trans_b
);
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
b
.
dims
(),
0
,
trans_b
);
int
head_number
=
1
;
int
head_number
=
1
;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
#endif
#endif
...
@@ -362,7 +365,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
...
@@ -362,7 +365,8 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
b
.
dims
(),
0
,
trans_b
);
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
b
.
dims
(),
0
,
trans_b
);
int
head_number
=
1
;
int
head_number
=
1
;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
#endif
#endif
...
@@ -562,7 +566,8 @@ class MatMulOp : public framework::OperatorWithKernel {
...
@@ -562,7 +566,8 @@ class MatMulOp : public framework::OperatorWithKernel {
DumpMatrixShape
(
mat_dim_y
).
c_str
()));
DumpMatrixShape
(
mat_dim_y
).
c_str
()));
}
}
int64_t
dim_out_y
=
mat_dim_y
.
width_
;
int64_t
dim_out_y
=
mat_dim_y
.
width_
;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
int
head_number
=
context
->
Attrs
().
Get
<
int
>
(
"head_number"
);
int
head_number
=
context
->
Attrs
().
Get
<
int
>
(
"head_number"
);
bool
split_vertical_y
=
(
mat_dim_x
.
width_
!=
mat_dim_y
.
height_
);
bool
split_vertical_y
=
(
mat_dim_x
.
width_
!=
mat_dim_y
.
height_
);
if
(
context
->
IsRuntime
())
{
if
(
context
->
IsRuntime
())
{
...
@@ -750,7 +755,8 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -750,7 +755,8 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"used in MKL-DNN INT8"
)
"used in MKL-DNN INT8"
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
AddAttr
<
int
>
(
"head_number"
,
"The number of heads of the matrix"
)
AddAttr
<
int
>
(
"head_number"
,
"The number of heads of the matrix"
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
#endif
#endif
...
@@ -916,7 +922,7 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -916,7 +922,7 @@ REGISTER_OP_CPU_KERNEL(
ops
::
MatMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MatMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MatMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
ops
::
MatMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
matmul
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
matmul
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
...
...
paddle/fluid/operators/mean_op.cu
浏览文件 @
59940cb3
...
@@ -11,7 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,7 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
...
...
paddle/fluid/operators/merge_lod_tensor_op.cc
浏览文件 @
59940cb3
...
@@ -65,7 +65,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
...
@@ -65,7 +65,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
if
(
platform
::
is_cpu_place
(
mask
.
place
()))
{
if
(
platform
::
is_cpu_place
(
mask
.
place
()))
{
cpu_mask
->
ShareDataWith
(
mask
);
cpu_mask
->
ShareDataWith
(
mask
);
}
else
if
(
platform
::
is_gpu_place
(
mask
.
place
()))
{
}
else
if
(
platform
::
is_gpu_place
(
mask
.
place
()))
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
framework
::
TensorCopy
(
mask
,
platform
::
CPUPlace
(),
dev_ctx
,
framework
::
TensorCopy
(
mask
,
platform
::
CPUPlace
(),
dev_ctx
,
cpu_mask
.
get
());
cpu_mask
.
get
());
#else
#else
...
...
paddle/fluid/operators/miopen_lstm_cache.h
0 → 100644
浏览文件 @
59940cb3
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/miopen_helper.h"
namespace
paddle
{
namespace
operators
{
class
ScopedRNNBase
{
public:
ScopedRNNBase
(
int
seq_length
,
int
batch_size
,
int
input_size
,
int
hidden_size
,
int
num_layers
,
float
dropout_prob
,
int
seed
,
int
weight_numel
,
bool
initialized
,
bool
is_bidirec
)
:
seq_length_
(
seq_length
),
batch_size_
(
batch_size
),
input_size_
(
input_size
),
hidden_size_
(
hidden_size
),
num_layers_
(
num_layers
),
dropout_prob_
(
dropout_prob
),
seed_
(
seed
),
weight_numel_
(
weight_numel
),
initialized_
(
initialized
),
is_bidirec_
(
is_bidirec
)
{}
template
<
typename
T
>
void
Create
(
const
miopenHandle_t
&
handle
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
int
>&
sequence_length
,
size_t
*
workspace_size
,
size_t
*
reserve_size
,
framework
::
Tensor
*
dropout_state
)
{
int
numDirections
=
is_bidirec_
?
2
:
1
;
miopenDataType_t
miopen_type
=
platform
::
CudnnDataType
<
T
>::
type
;
// ------------------- miopen x, y descriptors ---------------------
std
::
vector
<
int
>
dims_x
=
{
batch_size_
,
input_size_
,
1
};
std
::
vector
<
int
>
strides_x
=
{
input_size_
,
1
,
1
};
std
::
vector
<
int
>
dims_y
=
{
batch_size_
,
hidden_size_
*
numDirections
,
1
};
std
::
vector
<
int
>
strides_y
=
{
hidden_size_
*
numDirections
,
1
,
1
};
for
(
int
i
=
0
;
i
<
seq_length_
;
++
i
)
{
x_descs_
.
emplace_back
(
x_desc_
.
descriptor
<
T
>
(
dims_x
,
strides_x
));
y_descs_
.
emplace_back
(
y_desc_
.
descriptor
<
T
>
(
dims_y
,
strides_y
));
}
// ------------------- miopen hx, hy, cx, cy descriptors----------
std
::
vector
<
int
>
dims_hx
=
{
num_layers_
*
numDirections
,
batch_size_
,
hidden_size_
};
std
::
vector
<
int
>
strides_hx
=
{
hidden_size_
*
batch_size_
,
hidden_size_
,
1
};
init_h_desc_
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
init_c_desc_
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
last_h_desc_
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
last_c_desc_
.
descriptor
<
T
>
(
dims_hx
,
strides_hx
);
// ------------------- miopen dropout descriptors ---------------------
size_t
state_size
;
if
(
!
initialized_
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDropoutGetStatesSize
(
handle
,
&
state_size
));
dropout_state
->
mutable_data
<
uint8_t
>
({
static_cast
<
int64_t
>
(
state_size
)},
place
);
}
dropout_desc_
.
descriptor
(
handle
,
place
,
initialized_
,
dropout_prob_
,
dropout_state
,
seed_
,
state_size
);
// ------------------- miopen rnn descriptors ---------------------
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetRNNDescriptor
(
rnn_desc_
.
desc
(),
hidden_size_
,
num_layers_
,
miopenRNNlinear
,
is_bidirec_
?
miopenRNNbidirection
:
miopenRNNunidirection
,
miopenLSTM
,
miopenRNNNoBias
,
miopenRNNdefault
,
miopen_type
));
// ------------------- miopen weights_size ---------------------
size_t
weights_size_
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenGetRNNParamsSize
(
handle
,
rnn_desc_
.
desc
(),
x_descs_
[
0
],
&
weights_size_
,
miopen_type
));
PADDLE_ENFORCE_EQ
(
weights_size_
,
sizeof
(
T
)
*
weight_numel_
,
platform
::
errors
::
InvalidArgument
(
"The miopen lstm and setting weight size should be same."
));
// ------------------- miopen weight descriptors ---------------------
platform
::
DataLayout
layout
=
platform
::
DataLayout
::
kNCHW
;
int
dim_tmp
=
weights_size_
/
sizeof
(
T
);
std
::
vector
<
int
>
dim_w
=
{
dim_tmp
,
1
,
1
};
weight_desc_
.
descriptor
<
T
>
(
layout
,
dim_w
);
// ------------------- miopen workspace, reserve size ---------------------
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenGetRNNWorkspaceSize
(
handle
,
rnn_desc_
.
desc
(),
seq_length_
,
x_descs_
.
data
(),
workspace_size
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenGetRNNTrainingReserveSize
(
handle
,
rnn_desc_
.
desc
(),
seq_length_
,
x_descs_
.
data
(),
reserve_size
));
}
miopenTensorDescriptor_t
*
x_descs
()
{
return
x_descs_
.
data
();
}
miopenTensorDescriptor_t
*
y_descs
()
{
return
y_descs_
.
data
();
}
miopenTensorDescriptor_t
init_h_desc
()
{
return
init_h_desc_
.
desc
();
}
miopenTensorDescriptor_t
init_c_desc
()
{
return
init_c_desc_
.
desc
();
}
miopenTensorDescriptor_t
last_h_desc
()
{
return
last_h_desc_
.
desc
();
}
miopenTensorDescriptor_t
last_c_desc
()
{
return
last_c_desc_
.
desc
();
}
miopenRNNDescriptor_t
rnn_desc
()
{
return
rnn_desc_
.
desc
();
}
miopenDropoutDescriptor_t
dropout_desc
()
{
return
dropout_desc_
.
desc
();
}
miopenTensorDescriptor_t
weight_desc
()
{
return
weight_desc_
.
desc
();
}
private:
int
seq_length_
;
int
batch_size_
;
int
input_size_
;
int
hidden_size_
;
int
num_layers_
;
float
dropout_prob_
;
int
seed_
;
int
weight_numel_
;
bool
initialized_
;
bool
is_bidirec_
;
std
::
vector
<
miopenTensorDescriptor_t
>
x_descs_
;
std
::
vector
<
miopenTensorDescriptor_t
>
y_descs_
;
platform
::
ScopedTensorDescriptor
x_desc_
;
platform
::
ScopedTensorDescriptor
y_desc_
;
platform
::
ScopedTensorDescriptor
init_h_desc_
;
platform
::
ScopedTensorDescriptor
init_c_desc_
;
platform
::
ScopedTensorDescriptor
last_h_desc_
;
platform
::
ScopedTensorDescriptor
last_c_desc_
;
platform
::
ScopedDropoutDescriptor
dropout_desc_
;
platform
::
ScopedFilterDescriptor
weight_desc_
;
platform
::
ScopedRNNDescriptor
rnn_desc_
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/miopen_rnn_cache.h
0 → 100644
浏览文件 @
59940cb3
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/miopen_helper.h"
namespace
paddle
{
namespace
operators
{
struct
CudnnRNNCache
{
CudnnRNNCache
()
{
x_desc_
=
NULL
;
y_desc_
=
NULL
;
}
~
CudnnRNNCache
()
{
release
();
}
miopenRNNDescriptor_t
rnn_desc_
;
miopenTensorDescriptor_t
*
x_desc_
;
miopenTensorDescriptor_t
*
y_desc_
;
miopenTensorDescriptor_t
hx_desc_
;
miopenTensorDescriptor_t
cx_desc_
;
miopenTensorDescriptor_t
hy_desc_
;
miopenTensorDescriptor_t
cy_desc_
;
miopenTensorDescriptor_t
dhx_desc_
;
miopenTensorDescriptor_t
dcx_desc_
;
miopenTensorDescriptor_t
dhy_desc_
;
miopenTensorDescriptor_t
dcy_desc_
;
miopenTensorDescriptor_t
output_x_desc_
;
miopenTensorDescriptor_t
output_y_desc_
;
miopenDropoutDescriptor_t
dropout_desc_
;
size_t
weights_size_
;
miopenTensorDescriptor_t
w_desc_
;
miopenTensorDescriptor_t
dw_desc_
;
size_t
workspace_size_
;
framework
::
Tensor
workspace_data_
;
size_t
seq_length_
;
float
dropout_prob_
;
bool
is_bidirec_
;
int
batch_size_
;
int
input_size_
;
int
hidden_size_
;
int
num_layers_
;
int
seed_
;
void
init
(
miopenHandle_t
handle
,
const
platform
::
Place
&
place
,
size_t
seq_len
,
int
batch_size
,
int
input_size
,
int
hidden_size
,
int
num_layers
,
float
dropout_prob
,
bool
is_bidirec
,
int
seed
,
int
weight_numel
,
size_t
*
reserve_size_
,
framework
::
Tensor
*
dropout_state_
,
bool
initialized
,
miopenDataType_t
miopen_type
)
{
seq_length_
=
seq_len
;
batch_size_
=
batch_size
;
input_size_
=
input_size
;
hidden_size_
=
hidden_size
;
num_layers_
=
num_layers
;
dropout_prob_
=
dropout_prob
;
is_bidirec_
=
is_bidirec
;
seed_
=
seed
;
const
auto
numDirections
=
is_bidirec_
?
2
:
1
;
PADDLE_ENFORCE_EQ
(
miopen_type
,
miopenFloat
,
platform
::
errors
::
InvalidArgument
(
"MIOPEN do not support double datatype."
));
auto
miopen_size
=
sizeof
(
float
);
x_desc_
=
new
miopenTensorDescriptor_t
[
seq_length_
];
y_desc_
=
new
miopenTensorDescriptor_t
[
seq_length_
];
std
::
vector
<
int
>
dims
=
{
batch_size_
,
input_size_
,
1
};
std
::
vector
<
int
>
strides
=
{
input_size_
,
1
,
1
};
std
::
vector
<
int
>
dims_y
=
{
batch_size_
,
hidden_size_
*
numDirections
,
1
};
std
::
vector
<
int
>
strides_y
=
{
hidden_size_
*
numDirections
,
1
,
1
};
for
(
size_t
i
=
0
;
i
<
seq_length_
;
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
x_desc_
[
i
]));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
y_desc_
[
i
]));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
x_desc_
[
i
],
miopen_type
,
3
,
const_cast
<
int
*>
(
dims
.
data
()),
const_cast
<
int
*>
(
strides
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
y_desc_
[
i
],
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_y
.
data
()),
const_cast
<
int
*>
(
strides_y
.
data
())));
}
std
::
vector
<
int
>
dims_hx
=
{
num_layers_
*
numDirections
,
batch_size_
,
hidden_size_
};
std
::
vector
<
int
>
strides_hx
=
{
hidden_size_
*
batch_size_
,
hidden_size_
,
1
};
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
hx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
cx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
hy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
cy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
dhx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
dcx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
dhy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
dcy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
hx_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
cx_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
hy_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
cy_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
dhx_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
dcx_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
dhy_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
dcy_desc_
,
miopen_type
,
3
,
const_cast
<
int
*>
(
dims_hx
.
data
()),
const_cast
<
int
*>
(
strides_hx
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateDropoutDescriptor
(
&
dropout_desc_
));
size_t
state_size
;
if
(
!
initialized
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDropoutGetStatesSize
(
handle
,
&
state_size
));
dropout_state_
->
Resize
({
static_cast
<
int64_t
>
(
state_size
)});
uint8_t
*
dropout_state_data
=
dropout_state_
->
mutable_data
<
uint8_t
>
(
place
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetDropoutDescriptor
(
dropout_desc_
,
handle
,
dropout_prob_
,
dropout_state_data
,
state_size
,
seed_
,
false
,
false
,
MIOPEN_RNG_PSEUDO_XORWOW
));
}
else
{
uint8_t
*
dropout_state_data
=
dropout_state_
->
data
<
uint8_t
>
();
auto
dropout_state_dims
=
dropout_state_
->
dims
();
state_size
=
dropout_state_dims
[
0
];
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenRestoreDropoutDescriptor
(
dropout_desc_
,
handle
,
dropout_prob_
,
dropout_state_data
,
state_size
,
0
,
false
,
false
,
MIOPEN_RNG_PSEUDO_XORWOW
));
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateRNNDescriptor
(
&
rnn_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetRNNDescriptor
(
rnn_desc_
,
hidden_size_
,
num_layers_
,
miopenRNNlinear
,
is_bidirec_
?
miopenRNNbidirection
:
miopenRNNunidirection
,
miopenLSTM
,
miopenRNNNoBias
,
miopenRNNdefault
,
miopen_type
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
w_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
dw_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenGetRNNParamsSize
(
handle
,
rnn_desc_
,
x_desc_
[
0
],
&
weights_size_
,
miopen_type
));
PADDLE_ENFORCE_EQ
(
weights_size_
,
miopen_size
*
weight_numel
,
platform
::
errors
::
InvalidArgument
(
"The miopen lstm and setting weight size should be same."
));
int
dim_w
[
3
];
dim_w
[
0
]
=
weights_size_
/
miopen_size
;
dim_w
[
1
]
=
1
;
dim_w
[
2
]
=
1
;
int
dim_s
[
2
];
dim_s
[
1
]
=
1
;
dim_s
[
0
]
=
dim_w
[
1
];
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
w_desc_
,
miopen_type
,
3
,
dim_w
,
dim_s
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
dw_desc_
,
miopen_type
,
3
,
dim_w
,
dim_s
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenGetRNNWorkspaceSize
(
handle
,
rnn_desc_
,
seq_length_
,
x_desc_
,
&
workspace_size_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenGetRNNTrainingReserveSize
(
handle
,
rnn_desc_
,
seq_length_
,
x_desc_
,
reserve_size_
));
workspace_data_
.
Resize
({
static_cast
<
int64_t
>
(
workspace_size_
)});
workspace_data_
.
mutable_data
<
uint8_t
>
(
place
);
}
void
release
()
{
for
(
size_t
i
=
0
;
i
<
seq_length_
;
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
x_desc_
[
i
]));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
y_desc_
[
i
]));
}
delete
[]
x_desc_
;
delete
[]
y_desc_
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
hx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
cx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
hy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
cy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
dhx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
dcx_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
dhy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
dcy_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyDropoutDescriptor
(
dropout_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyRNNDescriptor
(
rnn_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
w_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
dw_desc_
));
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/modified_huber_loss_op.h
浏览文件 @
59940cb3
...
@@ -29,8 +29,8 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
...
@@ -29,8 +29,8 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template
<
typename
T
>
template
<
typename
T
>
struct
CheckLabelValue
{
struct
CheckLabelValue
{
HOSTDEVICE
T
operator
()(
const
T
&
val
)
const
{
HOSTDEVICE
T
operator
()(
const
T
&
val
)
const
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
_EQ
(
val
==
static_cast
<
T
>
(
0
)
||
val
==
static_cast
<
T
>
(
1
),
val
==
static_cast
<
T
>
(
0
)
||
val
==
static_cast
<
T
>
(
1
),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Input(label) value of modified_huber_loss_op expected to be 0 "
"Input(label) value of modified_huber_loss_op expected to be 0 "
"or 1, but got %ld. Please check label value."
,
"or 1, but got %ld. Please check label value."
,
...
...
paddle/fluid/operators/multinomial_op.cu
浏览文件 @
59940cb3
...
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include <thrust/execution_policy.h>
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/scan.h>
...
@@ -155,13 +159,24 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
...
@@ -155,13 +159,24 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
T
*
cpu_in_data
=
new
T
[
in_data_numel
];
T
*
cpu_in_data
=
new
T
[
in_data_numel
];
int64_t
*
cpu_out_data
=
new
int64_t
[
out_data_numel
];
int64_t
*
cpu_out_data
=
new
int64_t
[
out_data_numel
];
#ifdef PADDLE_WITH_HIP
hipMemcpy
(
cpu_in_data
,
in_data
,
in_data_numel
*
sizeof
(
T
),
hipMemcpyDeviceToHost
);
#else
cudaMemcpy
(
cpu_in_data
,
in_data
,
in_data_numel
*
sizeof
(
T
),
cudaMemcpy
(
cpu_in_data
,
in_data
,
in_data_numel
*
sizeof
(
T
),
cudaMemcpyDeviceToHost
);
cudaMemcpyDeviceToHost
);
#endif
MultinomialFunctor
<
T
>
(
cpu_out_data
,
cpu_in_data
,
num_samples
,
replacement
,
MultinomialFunctor
<
T
>
(
cpu_out_data
,
cpu_in_data
,
num_samples
,
replacement
,
num_categories
,
num_distributions
);
num_categories
,
num_distributions
);
#ifdef PADDLE_WITH_HIP
hipMemcpy
(
out_data
,
cpu_out_data
,
out_data_numel
*
sizeof
(
int64_t
),
hipMemcpyHostToDevice
);
#else
cudaMemcpy
(
out_data
,
cpu_out_data
,
out_data_numel
*
sizeof
(
int64_t
),
cudaMemcpy
(
out_data
,
cpu_out_data
,
out_data_numel
*
sizeof
(
int64_t
),
cudaMemcpyHostToDevice
);
cudaMemcpyHostToDevice
);
#endif
delete
[]
cpu_in_data
;
delete
[]
cpu_in_data
;
delete
[]
cpu_out_data
;
delete
[]
cpu_out_data
;
...
@@ -250,5 +265,7 @@ namespace ops = paddle::operators;
...
@@ -250,5 +265,7 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
multinomial
,
ops
::
MultinomialOpKernel
<
plat
::
CUDADeviceContext
,
float
>
,
multinomial
,
ops
::
MultinomialOpKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
MultinomialOpKernel
<
plat
::
CUDADeviceContext
,
double
>
);
ops
::
MultinomialOpKernel
<
plat
::
CUDADeviceContext
,
float
>
);
#endif
paddle/fluid/operators/nll_loss_op.cu
浏览文件 @
59940cb3
...
@@ -11,7 +11,6 @@ limitations under the License. */
...
@@ -11,7 +11,6 @@ limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <functional>
#include <functional>
#include <string>
#include <string>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/nll_loss_op.h"
#include "paddle/fluid/operators/nll_loss_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
@@ -361,7 +360,11 @@ class NLLLossCUDAKernel : public framework::OpKernel<T> {
...
@@ -361,7 +360,11 @@ class NLLLossCUDAKernel : public framework::OpKernel<T> {
auto
total_weight_data
=
total_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
total_weight_data
=
total_weight
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
label_data
=
labels
->
data
<
int64_t
>
();
auto
label_data
=
labels
->
data
<
int64_t
>
();
auto
weight_data
=
weight
?
weight
->
data
<
T
>
()
:
nullptr
;
auto
weight_data
=
weight
?
weight
->
data
<
T
>
()
:
nullptr
;
#ifdef PADDLE_WITH_HIP
hipMemset
(
total_weight_data
,
0
,
sizeof
(
T
));
#else
cudaMemset
(
total_weight_data
,
0
,
sizeof
(
T
));
cudaMemset
(
total_weight_data
,
0
,
sizeof
(
T
));
#endif
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
->
dims
();
auto
batch_size
=
x_dims
[
0
];
auto
batch_size
=
x_dims
[
0
];
auto
n_classes
=
x_dims
[
1
];
auto
n_classes
=
x_dims
[
1
];
...
@@ -429,7 +432,11 @@ class NLLLossGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -429,7 +432,11 @@ class NLLLossGradCUDAKernel : public framework::OpKernel<T> {
auto
total_weight_data
=
total_weight
->
data
<
T
>
();
auto
total_weight_data
=
total_weight
->
data
<
T
>
();
auto
ignore_index
=
ctx
.
Attr
<
int64_t
>
(
"ignore_index"
);
auto
ignore_index
=
ctx
.
Attr
<
int64_t
>
(
"ignore_index"
);
auto
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
auto
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
#ifdef PADDLE_WITH_HIP
hipMemset
(
dx_data
,
0
,
dx
->
numel
()
*
sizeof
(
T
));
#else
cudaMemset
(
dx_data
,
0
,
dx
->
numel
()
*
sizeof
(
T
));
cudaMemset
(
dx_data
,
0
,
dx
->
numel
()
*
sizeof
(
T
));
#endif
int64_t
size_average
=
(
int64_t
)(
reduction
==
"mean"
);
int64_t
size_average
=
(
int64_t
)(
reduction
==
"mean"
);
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
->
dims
();
...
...
paddle/fluid/operators/norm_op.cu
浏览文件 @
59940cb3
...
@@ -13,7 +13,13 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <algorithm>
#ifdef __NVCC__
#include "cub/cub.cuh"
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/norm_op.h"
#include "paddle/fluid/operators/norm_op.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/operators/norm_utils.cu.h
浏览文件 @
59940cb3
...
@@ -17,10 +17,20 @@ limitations under the License. */
...
@@ -17,10 +17,20 @@ limitations under the License. */
#include <cfloat>
#include <cfloat>
#include <string>
#include <string>
#include <vector>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录