Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
946dbdae
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2297
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
946dbdae
编写于
3月 03, 2021
作者:
Q
Qi Li
提交者:
GitHub
3月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] update fluid operators for rocm (part6), test=develop (#31301)
上级
1cbccfa5
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
350 addition
and
36 deletion
+350
-36
paddle/fluid/operators/activation_cudnn.cu.cc
paddle/fluid/operators/activation_cudnn.cu.cc
+4
-0
paddle/fluid/operators/activation_cudnn_op.cu.cc
paddle/fluid/operators/activation_cudnn_op.cu.cc
+71
-9
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+0
-3
paddle/fluid/operators/affine_channel_op.cu
paddle/fluid/operators/affine_channel_op.cu
+8
-0
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
+5
-0
paddle/fluid/operators/affine_grid_op.cc
paddle/fluid/operators/affine_grid_op.cc
+5
-2
paddle/fluid/operators/allclose_op.cu
paddle/fluid/operators/allclose_op.cu
+4
-1
paddle/fluid/operators/arg_min_max_op_base.cu.h
paddle/fluid/operators/arg_min_max_op_base.cu.h
+8
-2
paddle/fluid/operators/argsort_op.cu
paddle/fluid/operators/argsort_op.cu
+17
-1
paddle/fluid/operators/batch_fc_op.cu
paddle/fluid/operators/batch_fc_op.cu
+2
-3
paddle/fluid/operators/batch_norm_op.cu
paddle/fluid/operators/batch_norm_op.cu
+173
-6
paddle/fluid/operators/bce_loss_op.cu
paddle/fluid/operators/bce_loss_op.cu
+0
-1
paddle/fluid/operators/math/sequence_padding_test.cc
paddle/fluid/operators/math/sequence_padding_test.cc
+1
-1
paddle/fluid/operators/math/sequence_pooling_test.cc
paddle/fluid/operators/math/sequence_pooling_test.cc
+1
-1
paddle/fluid/operators/math/sequence_scale.cu
paddle/fluid/operators/math/sequence_scale.cu
+8
-0
paddle/fluid/operators/math/softmax.cu
paddle/fluid/operators/math/softmax.cu
+36
-3
paddle/fluid/operators/math/softmax.h
paddle/fluid/operators/math/softmax.h
+1
-1
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_op.h
+1
-1
python/paddle/fluid/tests/unittests/op_test.py
python/paddle/fluid/tests/unittests/op_test.py
+5
-1
未找到文件。
paddle/fluid/operators/activation_cudnn.cu.cc
浏览文件 @
946dbdae
...
...
@@ -14,7 +14,11 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_desc.h"
#else
#include "paddle/fluid/platform/cudnn_desc.h"
#endif
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/activation_cudnn_op.cu.cc
浏览文件 @
946dbdae
...
...
@@ -14,7 +14,11 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_desc.h"
#else
#include "paddle/fluid/platform/cudnn_desc.h"
#endif
namespace
paddle
{
namespace
platform
{
...
...
@@ -29,35 +33,71 @@ using platform::ActivationDescriptor;
using
platform
::
TensorDescriptor
;
using
platform
::
CUDADeviceContext
;
// Backend-neutral names for the activation modes used by the functors in this
// file. When PADDLE_WITH_HIP is defined they expand to the MIOpen enumerators,
// otherwise to the cuDNN ones. Note that the sigmoid mode is spelled
// miopenActivationLOGISTIC on the MIOpen side.
#ifdef PADDLE_WITH_HIP
#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
#else
#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
#endif
template
<
typename
T
>
struct
CudnnActivationFunctor
{
using
ELEMENT_TYPE
=
T
;
#ifdef PADDLE_WITH_HIP
CudnnActivationFunctor
(
const
CUDADeviceContext
&
ctx
,
const
T
&
c
,
const
miopenActivationMode_t
&
m
)
:
ctx_
(
ctx
),
coef_
(
c
),
mode_
(
m
)
{}
#else
CudnnActivationFunctor
(
const
CUDADeviceContext
&
ctx
,
const
T
&
c
,
const
cudnnActivationMode_t
&
m
)
:
ctx_
(
ctx
),
coef_
(
c
),
mode_
(
m
)
{}
#endif
void
operator
()(
const
Tensor
&
x
,
Tensor
*
out
)
{
ActivationDescriptor
act_desc
;
act_desc
.
set
(
mode_
,
coef_
);
TensorDescriptor
x_desc
,
out_desc
;
x_desc
.
set
(
x
);
out_desc
.
set
(
GET_DATA_SAFELY
(
out
,
"Output"
,
"Out"
,
"CudnnActivation"
));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenActivationForward
(
ctx_
.
cudnn_handle
(),
act_desc
.
desc
(),
platform
::
CudnnDataType
<
T
>::
kOne
(),
x_desc
.
desc
(),
x
.
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
out_desc
.
desc
(),
out
->
mutable_data
<
T
>
(
ctx_
.
GetPlace
())));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnActivationForward
(
ctx_
.
cudnn_handle
(),
act_desc
.
desc
(),
platform
::
CudnnDataType
<
T
>::
kOne
(),
x_desc
.
desc
(),
x
.
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
out_desc
.
desc
(),
out
->
mutable_data
<
T
>
(
ctx_
.
GetPlace
())));
#endif
}
const
CUDADeviceContext
&
ctx_
;
const
T
coef_
;
#ifdef PADDLE_WITH_HIP
const
miopenActivationMode_t
mode_
;
#else
const
cudnnActivationMode_t
mode_
;
#endif
};
template
<
typename
T
>
struct
CudnnActivationGradFunctor
{
using
ELEMENT_TYPE
=
T
;
#ifdef PADDLE_WITH_HIP
CudnnActivationGradFunctor
(
const
CUDADeviceContext
&
ctx
,
const
T
&
c
,
const
miopenActivationMode_t
&
m
)
:
ctx_
(
ctx
),
coef_
(
c
),
mode_
(
m
)
{}
#else
CudnnActivationGradFunctor
(
const
CUDADeviceContext
&
ctx
,
const
T
&
c
,
const
cudnnActivationMode_t
&
m
)
:
ctx_
(
ctx
),
coef_
(
c
),
mode_
(
m
)
{}
#endif
void
operator
()(
const
Tensor
&
x
,
const
Tensor
&
out
,
const
Tensor
dout
,
Tensor
*
dx
)
{
ActivationDescriptor
act_desc
;
...
...
@@ -67,27 +107,40 @@ struct CudnnActivationGradFunctor {
out_desc
.
set
(
out
);
dout_desc
.
set
(
dout
);
dx_desc
.
set
(
GET_DATA_SAFELY
(
dx
,
"Output"
,
"X@GRAD"
,
"CudnnActivationGrad"
));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenActivationBackward
(
ctx_
.
cudnn_handle
(),
act_desc
.
desc
(),
platform
::
CudnnDataType
<
T
>::
kOne
(),
out_desc
.
desc
(),
out
.
data
<
T
>
(),
dout_desc
.
desc
(),
dout
.
data
<
T
>
(),
x_desc
.
desc
(),
x
.
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
dx_desc
.
desc
(),
dx
->
mutable_data
<
T
>
(
ctx_
.
GetPlace
())));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnActivationBackward
(
ctx_
.
cudnn_handle
(),
act_desc
.
desc
(),
platform
::
CudnnDataType
<
T
>::
kOne
(),
out_desc
.
desc
(),
out
.
data
<
T
>
(),
dout_desc
.
desc
(),
dout
.
data
<
T
>
(),
x_desc
.
desc
(),
x
.
data
<
T
>
(),
platform
::
CudnnDataType
<
T
>::
kZero
(),
dx_desc
.
desc
(),
dx
->
mutable_data
<
T
>
(
ctx_
.
GetPlace
())));
#endif
}
const
CUDADeviceContext
&
ctx_
;
const
T
coef_
;
#ifdef PADDLE_WITH_HIP
const
miopenActivationMode_t
mode_
;
#else
const
cudnnActivationMode_t
mode_
;
#endif
};
template
<
typename
T
>
struct
CudnnReluFunctor
:
public
CudnnActivationFunctor
<
T
>
{
explicit
CudnnReluFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationFunctor
<
T
>
(
ctx
,
0.0
,
C
UDNN_ACTIVATION_RELU
)
{}
:
CudnnActivationFunctor
<
T
>
(
ctx
,
0.0
,
GP
UDNN_ACTIVATION_RELU
)
{}
};
template
<
typename
T
>
struct
CudnnReluGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
explicit
CudnnReluGradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
C
UDNN_ACTIVATION_RELU
)
{}
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
GP
UDNN_ACTIVATION_RELU
)
{}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
...
...
@@ -95,13 +148,13 @@ struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
template
<
typename
T
>
struct
CudnnRelu6Functor
:
public
CudnnActivationFunctor
<
T
>
{
explicit
CudnnRelu6Functor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationFunctor
<
T
>
(
ctx
,
6.0
,
C
UDNN_ACTIVATION_CLIPPED_RELU
)
{}
:
CudnnActivationFunctor
<
T
>
(
ctx
,
6.0
,
GP
UDNN_ACTIVATION_CLIPPED_RELU
)
{}
};
template
<
typename
T
>
struct
CudnnRelu6GradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
explicit
CudnnRelu6GradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
6.0
,
CUDNN_ACTIVATION_CLIPPED_RELU
)
{
}
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
6.0
,
GPUDNN_ACTIVATION_CLIPPED_RELU
)
{
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
...
...
@@ -109,12 +162,12 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
template
<
typename
T
>
struct
CudnnSigmoidFunctor
:
public
CudnnActivationFunctor
<
T
>
{
explicit
CudnnSigmoidFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationFunctor
<
T
>
(
ctx
,
0.0
,
C
UDNN_ACTIVATION_SIGMOID
)
{}
:
CudnnActivationFunctor
<
T
>
(
ctx
,
0.0
,
GP
UDNN_ACTIVATION_SIGMOID
)
{}
};
template
<
typename
T
>
struct
CudnnSigmoidGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
explicit
CudnnSigmoidGradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
C
UDNN_ACTIVATION_SIGMOID
)
{}
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
GP
UDNN_ACTIVATION_SIGMOID
)
{}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
...
...
@@ -122,12 +175,12 @@ struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
template
<
typename
T
>
struct
CudnnTanhFunctor
:
public
CudnnActivationFunctor
<
T
>
{
explicit
CudnnTanhFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationFunctor
<
T
>
(
ctx
,
0.0
,
C
UDNN_ACTIVATION_TANH
)
{}
:
CudnnActivationFunctor
<
T
>
(
ctx
,
0.0
,
GP
UDNN_ACTIVATION_TANH
)
{}
};
template
<
typename
T
>
struct
CudnnTanhGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
explicit
CudnnTanhGradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
C
UDNN_ACTIVATION_TANH
)
{}
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
GP
UDNN_ACTIVATION_TANH
)
{}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
...
...
@@ -183,6 +236,14 @@ namespace ops = paddle::operators;
__macro(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \
__macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)
#ifdef PADDLE_WITH_HIP
// HIP variant: registers the CUDNN-library kernel for `act_type` and for
// `act_type##_grad` on plat::CUDAPlace, instantiated for float only (the
// non-HIP variant of this macro additionally registers double).
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace, \
ops::CudnnActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, CUDNN, plat::CUDAPlace, \
ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
#else
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace, \
ops::CudnnActivationKernel<ops::functor<float>>, \
...
...
@@ -191,5 +252,6 @@ namespace ops = paddle::operators;
act_type##_grad, CUDNN, plat::CUDAPlace, \
ops::CudnnActivationGradKernel<ops::grad_functor<float>>, \
ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
#endif
FOR_EACH_CUDNN_OP_FUNCTOR
(
REGISTER_ACTIVATION_CUDNN_KERNEL
);
paddle/fluid/operators/activation_op.cc
浏览文件 @
946dbdae
...
...
@@ -24,9 +24,6 @@ limitations under the License. */
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
DECLARE_bool
(
use_mkldnn
);
...
...
paddle/fluid/operators/affine_channel_op.cu
浏览文件 @
946dbdae
...
...
@@ -12,7 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
浏览文件 @
946dbdae
...
...
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP does not support cudnnSpatialTfGridGeneratorForward
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
...
...
@@ -121,3 +124,5 @@ REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace,
REGISTER_OP_KERNEL
(
affine_grid_grad
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNAffineGridGradOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNAffineGridGradOpKernel
<
double
>
);
#endif // not PADDLE_WITH_HIP
paddle/fluid/operators/affine_grid_op.cc
浏览文件 @
946dbdae
...
...
@@ -21,6 +21,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -109,7 +112,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kCUDNN
;
}
...
...
@@ -226,7 +229,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
...
...
paddle/fluid/operators/allclose_op.cu
浏览文件 @
946dbdae
...
...
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/allclose_op.h"
...
...
@@ -67,7 +66,11 @@ struct AllcloseFunctor<platform::CUDADeviceContext, T> {
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
#ifdef PADDLE_WITH_HIP
hipMemset
(
out_data
,
true
,
sizeof
(
bool
));
#else
cudaMemset
(
out_data
,
true
,
sizeof
(
bool
));
#endif
AllcloseCUDAKernel
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
other_data
,
rtol
,
atol
,
equal_nan
,
num
,
out_data
);
}
...
...
paddle/fluid/operators/arg_min_max_op_base.cu.h
浏览文件 @
946dbdae
...
...
@@ -14,9 +14,15 @@ limitations under the License. */
#pragma once
#if
def __NVCC__
#if
defined(__NVCC__) || defined(__HIPCC__)
#include <cub/cub.cuh>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include <limits>
#include <string>
#include <typeinfo>
...
...
paddle/fluid/operators/argsort_op.cu
浏览文件 @
946dbdae
...
...
@@ -16,13 +16,28 @@ limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#ifdef __HIPCC__
// Radix key codec for paddle::platform::float16: the specialization reuses
// rocPRIM's integral codec with uint16_t as the underlying bit type.
// NOTE(review): presumably the HIP counterpart of the cub::NumericTraits
// specialization used in the non-HIP branch, so float16 keys can be
// radix-sorted - confirm against the rocPRIM version in use.
namespace rocprim {
namespace detail {

template <>
struct radix_key_codec_base<paddle::platform::float16>
    : radix_key_codec_integral<paddle::platform::float16, uint16_t> {};

}  // namespace detail
}  // namespace rocprim
#else
// set cub base traits in order to handle float16
namespace
cub
{
template
<
>
...
...
@@ -30,6 +45,7 @@ struct NumericTraits<paddle::platform::float16>
:
BaseTraits
<
FLOATING_POINT
,
true
,
false
,
uint16_t
,
paddle
::
platform
::
float16
>
{};
}
// namespace cub
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -139,7 +155,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
cub
::
CountingInputIterator
<
IndType
>>
segment_offsets_t
(
counting_iter
,
SegmentOffsetIter
(
num_cols
));
cuda
Error_t
err
;
gpu
Error_t
err
;
if
(
descending
)
{
err
=
cub
::
DeviceSegmentedRadixSort
::
SortPairsDescending
(
nullptr
,
temp_storage_bytes
,
inp
,
sorted_out_ptr
,
...
...
paddle/fluid/operators/batch_fc_op.cu
浏览文件 @
946dbdae
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cublas.h>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/batch_fc_op.h"
...
...
@@ -42,7 +41,7 @@ __global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
}
template
<
typename
T
>
void
add_bias
(
cuda
Stream_t
stream
,
T
*
data
,
int
slot_pairs_num
,
int
ins_num
,
void
add_bias
(
gpu
Stream_t
stream
,
T
*
data
,
int
slot_pairs_num
,
int
ins_num
,
int
out_dim
,
const
T
*
bias
)
{
add_bias_kernel
<<<
GET_BLOCKS
(
slot_pairs_num
*
ins_num
*
out_dim
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
data
,
slot_pairs_num
,
...
...
@@ -65,7 +64,7 @@ __global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
}
template
<
typename
T
>
void
add_bias_grad
(
cuda
Stream_t
stream
,
const
T
*
dout_data
,
int
slot_pairs_num
,
void
add_bias_grad
(
gpu
Stream_t
stream
,
const
T
*
dout_data
,
int
slot_pairs_num
,
int
ins_num
,
int
out_dim
,
T
*
db_data
)
{
add_bias_grad_kernel
<<<
GET_BLOCKS
(
slot_pairs_num
*
out_dim
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dout_data
,
slot_pairs_num
,
ins_num
,
...
...
paddle/fluid/operators/batch_norm_op.cu
浏览文件 @
946dbdae
...
...
@@ -16,12 +16,17 @@ limitations under the License. */
#include <cfloat>
#include <string>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool
(
cudnn_batchnorm_spatial_persistent
);
...
...
@@ -73,6 +78,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
ExtractNCWHD
(
x_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
auto
dtype
=
platform
::
CudnnDataType
<
T
>::
type
;
#ifdef PADDLE_WITH_HIP
// HIP does not support NHWC as the compute format
auto
compute_format
=
DataLayout
::
kNCHW
;
#else
const
bool
fast_nhwc_batch_norm
=
test_mode
||
(
dtype
==
CUDNN_DATA_HALF
&&
FLAGS_cudnn_batchnorm_spatial_persistent
);
...
...
@@ -81,6 +91,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
fast_nhwc_batch_norm
&&
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
#endif
Tensor
transformed_x
(
x
->
type
());
Tensor
transformed_y
(
y
->
type
());
...
...
@@ -98,7 +109,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
transformed_y
.
ShareDataWith
(
*
y
);
}
// ------------------- cudnn descriptors ---------------------
// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
data_desc_
;
miopenTensorDescriptor_t
bn_param_desc_
;
miopenBatchNormMode_t
mode_
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
bn_param_desc_
));
#else
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
bn_param_desc_
;
cudnnBatchNormMode_t
mode_
;
...
...
@@ -107,6 +128,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
bn_param_desc_
));
#endif
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
...
...
@@ -114,7 +136,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
<<
"CUDNN_BN_MIN_EPSILON instead."
;
}
epsilon
=
std
::
max
(
epsilon
,
CUDNN_BN_MIN_EPSILON
);
#if CUDNN_VERSION_MIN(7, 0, 1)
#ifdef PADDLE_WITH_HIP
mode_
=
miopenBNSpatial
;
#elif CUDNN_VERSION_MIN(7, 0, 1)
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
)
{
mode_
=
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
;
}
else
{
...
...
@@ -134,6 +159,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
dims
=
{
N
,
C
,
H
,
W
,
D
};
strides
=
{
H
*
W
*
D
*
C
,
1
,
W
*
D
*
C
,
D
*
C
,
C
};
}
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
const_cast
<
int
*>
(
dims
.
data
()),
const_cast
<
int
*>
(
strides
.
data
())));
// Note: PERSISTENT not implemented for inference
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDeriveBNTensorDescriptor
(
bn_param_desc_
,
data_desc_
,
test_mode
?
miopenBNSpatial
:
mode_
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
...
...
@@ -142,6 +178,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
bn_param_desc_
,
data_desc_
,
test_mode
?
CUDNN_BATCHNORM_SPATIAL
:
mode_
));
#endif
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
...
...
@@ -188,6 +225,30 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
"variance is [%d], the dimensions of variance is [%s]."
,
C
,
est_var
->
dims
()[
0
],
est_var
->
dims
()));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenBatchNormalizationForwardInference
(
handle
,
miopenBNSpatial
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kOne
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kZero
())),
data_desc_
,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
())),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
())),
epsilon
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnBatchNormalizationForwardInference
(
handle
,
...
...
@@ -200,6 +261,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
epsilon
));
#endif
}
else
{
// if MomentumTensor is set, use MomentumTensor value, momentum
// is only used in this training branch
...
...
@@ -302,6 +364,36 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
reserve_space_size
));
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if
(
!
called
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenBatchNormalizationForwardTraining
(
handle
,
mode_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kOne
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kZero
())),
data_desc_
,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
())),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
this_factor
,
static_cast
<
void
*>
(
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
epsilon
,
static_cast
<
void
*>
(
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()))));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnBatchNormalizationForwardTraining
(
handle
,
mode_
,
CudnnDataType
<
T
>::
kOne
(),
...
...
@@ -319,6 +411,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
ctx
.
GetPlace
()),
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())));
#endif
}
}
}
...
...
@@ -329,11 +422,19 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
TransToChannelLast
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
&
transformed_y
,
y
);
}
#ifdef PADDLE_WITH_HIP
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
bn_param_desc_
));
#else
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
bn_param_desc_
));
#endif
}
};
...
...
@@ -416,7 +517,7 @@ class InplaceHelper {
const
BatchNormParamType
<
T
>
*
mean
,
const
BatchNormParamType
<
T
>
*
variance
,
double
epsilon
,
int
C
,
int
M
,
const
int
num
,
const
T
*
y
,
int
grid2
,
const
int
block
,
const
cuda
Stream_t
&
stream
)
{
const
gpu
Stream_t
&
stream
)
{
PADDLE_ENFORCE_EQ
(
x
,
y
,
platform
::
errors
::
InvalidArgument
(
"X and Y should be inplaced in inplace mode"
));
KeBNRestoreData
<<<
grid2
,
block
,
0
,
stream
>>>
(
...
...
@@ -566,6 +667,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto
dtype
=
platform
::
CudnnDataType
<
T
>::
type
;
const
auto
*
reserve_space
=
ctx
.
Input
<
Tensor
>
(
"ReserveSpace"
);
#ifdef PADDLE_WITH_HIP
// HIP does not support NHWC as the compute format
auto
compute_format
=
DataLayout
::
kNCHW
;
#else
const
bool
fast_nhwc_batch_norm
=
dtype
==
CUDNN_DATA_HALF
&&
FLAGS_cudnn_batchnorm_spatial_persistent
&&
reserve_space
!=
nullptr
;
...
...
@@ -573,6 +678,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
fast_nhwc_batch_norm
&&
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
#endif
Tensor
transformed_x
(
x
->
type
());
Tensor
transformed_d_y
(
d_y
->
type
());
...
...
@@ -626,7 +732,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
return
;
}
// ------------------- cudnn descriptors ---------------------
// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
data_desc_
;
miopenTensorDescriptor_t
bn_param_desc_
;
miopenBatchNormMode_t
mode_
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenCreateTensorDescriptor
(
&
bn_param_desc_
));
#else
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
bn_param_desc_
;
cudnnBatchNormMode_t
mode_
;
...
...
@@ -635,13 +751,16 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
bn_param_desc_
));
#endif
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
<<
"CUDNN_BN_MIN_EPSILON. Setting it to "
<<
"CUDNN_BN_MIN_EPSILON instead."
;
}
epsilon
=
std
::
max
(
epsilon
,
CUDNN_BN_MIN_EPSILON
);
#if CUDNN_VERSION_MIN(7, 0, 1)
#ifdef PADDLE_WITH_HIP
mode_
=
miopenBNSpatial
;
#elif CUDNN_VERSION_MIN(7, 0, 1)
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
)
{
mode_
=
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
;
}
else
{
...
...
@@ -651,12 +770,22 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
mode_
=
CUDNN_BATCHNORM_SPATIAL
;
#endif // CUDNN_VERSION_MIN(7, 0, 1)
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSetTensorDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
const_cast
<
int
*>
(
dims
.
data
()),
const_cast
<
int
*>
(
strides
.
data
())));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDeriveBNTensorDescriptor
(
bn_param_desc_
,
data_desc_
,
mode_
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
x_dims
.
size
()
>
3
?
x_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
bn_param_desc_
,
data_desc_
,
mode_
));
#endif
const
auto
*
saved_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
saved_var
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
...
...
@@ -741,6 +870,22 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
/*reserveSpaceSizeInBytes=*/
reserve_space_size
));
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if
(
!
called
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
transformed_d_y
.
template
data
<
T
>(),
data_desc_
,
transformed_d_x
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
...
...
@@ -755,6 +900,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
#endif
}
if
(
data_layout
==
DataLayout
::
kNHWC
&&
...
...
@@ -784,11 +930,19 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
}
#ifdef PADDLE_WITH_HIP
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenDestroyTensorDescriptor
(
bn_param_desc_
));
#else
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
bn_param_desc_
));
#endif
}
else
{
const
auto
*
running_mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
running_var
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
...
...
@@ -886,6 +1040,18 @@ class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double
REGISTER_OP_CUDA_KERNEL
(
batch_norm
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
batch_norm_grad
,
ops
::
BatchNormGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
BatchNormGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
batch_norm_grad_grad
,
ops
::
BatchNormDoubleGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
batch_norm
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
double
>
,
...
...
@@ -898,3 +1064,4 @@ REGISTER_OP_CUDA_KERNEL(
batch_norm_grad_grad
,
ops
::
BatchNormDoubleGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
BatchNormDoubleGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
#endif
paddle/fluid/operators/bce_loss_op.cu
浏览文件 @
946dbdae
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/bce_loss_op.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
paddle/fluid/operators/math/sequence_padding_test.cc
浏览文件 @
946dbdae
...
...
@@ -105,7 +105,7 @@ TEST(Seq2BatchPadding, CPU) {
128
);
}
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST
(
SequencePadding
,
CUDA
)
{
auto
place
=
paddle
::
platform
::
CUDAPlace
(
0
);
auto
*
context
=
static_cast
<
paddle
::
platform
::
CUDADeviceContext
*>
(
...
...
paddle/fluid/operators/math/sequence_pooling_test.cc
浏览文件 @
946dbdae
...
...
@@ -123,7 +123,7 @@ TEST(SequencePoolingGrad, CPU_SUM) {
lod2
,
128
);
}
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST
(
SequencePoolingGrad
,
CUDA_SUM
)
{
auto
place
=
paddle
::
platform
::
CUDAPlace
(
0
);
auto
*
context
=
static_cast
<
paddle
::
platform
::
CUDADeviceContext
*>
(
...
...
paddle/fluid/operators/math/sequence_scale.cu
浏览文件 @
946dbdae
...
...
@@ -44,10 +44,18 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
T
*
seq_data
=
seq
->
mutable_data
<
T
>
(
context
.
GetPlace
());
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
SequenceScaleKernel
<
T
,
PADDLE_CUDA_NUM_THREADS
>
),
dim3
(
num_seq
),
dim3
(
PADDLE_CUDA_NUM_THREADS
),
0
,
context
.
stream
(),
seq_data
,
abs_offset_lod
[
level
].
CUDAMutableData
(
context
.
GetPlace
()),
scales
,
seq_width
);
#else
SequenceScaleKernel
<
T
,
PADDLE_CUDA_NUM_THREADS
><<<
num_seq
,
PADDLE_CUDA_NUM_THREADS
,
0
,
context
.
stream
()
>>>
(
seq_data
,
abs_offset_lod
[
level
].
CUDAMutableData
(
context
.
GetPlace
()),
scales
,
seq_width
);
#endif
}
};
...
...
paddle/fluid/operators/math/softmax.cu
浏览文件 @
946dbdae
...
...
@@ -16,7 +16,11 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -45,6 +49,16 @@ void SoftmaxCUDNNFunctor<T>::operator()(
if
(
cudnn_tensor_dims
.
size
()
<=
2
)
{
cudnn_tensor_dims
.
resize
(
4
,
1
);
}
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
cudnn_x_desc
=
xDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
miopenTensorDescriptor_t
cudnn_y_desc
=
xDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxForward
(
context
.
cudnn_handle
(),
CudnnDataType
<
T
>::
kOne
(),
cudnn_x_desc
,
X
->
data
<
T
>
(),
CudnnDataType
<
T
>::
kZero
(),
cudnn_y_desc
,
Y
->
mutable_data
<
T
>
(
context
.
GetPlace
())));
#else
cudnnTensorDescriptor_t
cudnn_x_desc
=
xDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
cudnnTensorDescriptor_t
cudnn_y_desc
=
...
...
@@ -54,6 +68,7 @@ void SoftmaxCUDNNFunctor<T>::operator()(
CUDNN_SOFTMAX_MODE_INSTANCE
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_x_desc
,
X
->
data
<
T
>
(),
CudnnDataType
<
T
>::
kZero
(),
cudnn_y_desc
,
Y
->
mutable_data
<
T
>
(
context
.
GetPlace
())));
#endif
}
template
<
typename
T
>
...
...
@@ -74,6 +89,19 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
if
(
cudnn_tensor_dims
.
size
()
<=
2
)
{
cudnn_tensor_dims
.
resize
(
4
,
1
);
}
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t
cudnn_y_desc
=
yDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
miopenTensorDescriptor_t
cudnn_xgrad_desc
=
dxDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
miopenTensorDescriptor_t
cudnn_ygrad_desc
=
dyDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
miopenSoftmaxBackward
(
context
.
cudnn_handle
(),
CudnnDataType
<
T
>::
kOne
(),
cudnn_y_desc
,
Y
->
data
<
T
>
(),
cudnn_ygrad_desc
,
YGrad
->
data
<
T
>
(),
CudnnDataType
<
T
>::
kZero
(),
cudnn_xgrad_desc
,
XGrad
->
mutable_data
<
T
>
(
context
.
GetPlace
())));
#else
cudnnTensorDescriptor_t
cudnn_y_desc
=
yDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
cudnnTensorDescriptor_t
cudnn_xgrad_desc
=
...
...
@@ -86,15 +114,20 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
Y
->
data
<
T
>
(),
cudnn_ygrad_desc
,
YGrad
->
data
<
T
>
(),
CudnnDataType
<
T
>::
kZero
(),
cudnn_xgrad_desc
,
XGrad
->
mutable_data
<
T
>
(
context
.
GetPlace
())));
#endif
}
template
class
SoftmaxCUDNNFunctor
<
platform
::
float16
>;
template
class
SoftmaxCUDNNFunctor
<
float
>;
template
class
SoftmaxCUDNNFunctor
<
double
>;
template
class
SoftmaxCUDNNFunctor
<
platform
::
float16
>;
template
class
SoftmaxGradCUDNNFunctor
<
float
>;
template
class
SoftmaxGradCUDNNFunctor
<
double
>;
template
class
SoftmaxGradCUDNNFunctor
<
platform
::
float16
>;
// MIOpen does not support double
#ifndef PADDLE_WITH_HIP
template
class
SoftmaxCUDNNFunctor
<
double
>;
template
class
SoftmaxGradCUDNNFunctor
<
double
>;
#endif
template
class
SoftmaxFunctor
<
platform
::
CUDADeviceContext
,
platform
::
float16
,
false
>;
template
class
SoftmaxFunctor
<
platform
::
CUDADeviceContext
,
platform
::
float16
,
...
...
paddle/fluid/operators/math/softmax.h
浏览文件 @
946dbdae
...
...
@@ -35,7 +35,7 @@ class SoftmaxGradFunctor {
framework
::
Tensor
*
x_grad
);
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template
<
typename
T
>
class
SoftmaxCUDNNFunctor
{
public:
...
...
paddle/fluid/operators/pool_op.h
浏览文件 @
946dbdae
...
...
@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#if
def __NVCC__
#if
defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif
...
...
python/paddle/fluid/tests/unittests/op_test.py
浏览文件 @
946dbdae
...
...
@@ -278,6 +278,9 @@ class OpTest(unittest.TestCase):
def
is_mkldnn_op_test
():
return
hasattr
(
cls
,
"use_mkldnn"
)
and
cls
.
use_mkldnn
==
True
def
is_rocm_op_test
():
return
core
.
is_compiled_with_rocm
()
if
not
hasattr
(
cls
,
"op_type"
):
raise
AssertionError
(
"This test do not have op_type in class attrs, "
...
...
@@ -298,7 +301,8 @@ class OpTest(unittest.TestCase):
and
cls
.
op_type
not
in
op_accuracy_white_list
.
NO_FP64_CHECK_GRAD_OP_LIST
\
and
not
hasattr
(
cls
,
'exist_fp64_check_grad'
)
\
and
not
is_xpu_op_test
()
\
and
not
is_mkldnn_op_test
():
and
not
is_mkldnn_op_test
()
\
and
not
is_rocm_op_test
():
raise
AssertionError
(
"This test of %s op needs check_grad with fp64 precision."
%
cls
.
op_type
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录