机器未来 / Paddle, commit 946dbdae (fork of PaddlePaddle / Paddle)
Unverified commit 946dbdae, authored by Qi Li on Mar 03, 2021, committed via GitHub on Mar 03, 2021.

[ROCM] update fluid operators for rocm (part6), test=develop (#31301)

Parent: 1cbccfa5
Showing 19 changed files with 350 additions and 36 deletions (+350 -36).
paddle/fluid/operators/activation_cudnn.cu.cc (+4 -0)
paddle/fluid/operators/activation_cudnn_op.cu.cc (+71 -9)
paddle/fluid/operators/activation_op.cc (+0 -3)
paddle/fluid/operators/affine_channel_op.cu (+8 -0)
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc (+5 -0)
paddle/fluid/operators/affine_grid_op.cc (+5 -2)
paddle/fluid/operators/allclose_op.cu (+4 -1)
paddle/fluid/operators/arg_min_max_op_base.cu.h (+8 -2)
paddle/fluid/operators/argsort_op.cu (+17 -1)
paddle/fluid/operators/batch_fc_op.cu (+2 -3)
paddle/fluid/operators/batch_norm_op.cu (+173 -6)
paddle/fluid/operators/bce_loss_op.cu (+0 -1)
paddle/fluid/operators/math/sequence_padding_test.cc (+1 -1)
paddle/fluid/operators/math/sequence_pooling_test.cc (+1 -1)
paddle/fluid/operators/math/sequence_scale.cu (+8 -0)
paddle/fluid/operators/math/softmax.cu (+36 -3)
paddle/fluid/operators/math/softmax.h (+1 -1)
paddle/fluid/operators/pool_op.h (+1 -1)
python/paddle/fluid/tests/unittests/op_test.py (+5 -1)
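Taken together, the diffs below apply a small set of recurring porting patterns: backend-specific headers and API calls are guarded with `PADDLE_WITH_HIP` (a build flag) or `__HIPCC__`/`__NVCC__` (compiler macros), hipCUB is aliased as `cub`, cuDNN/MIOpen activation enums are mapped onto shared `GPUDNN_*` names, and existing `#ifdef PADDLE_WITH_CUDA` guards are widened to also cover HIP. As a standalone, compile-only sketch of the core idiom (not Paddle code; Paddle's real aliases live in its platform layer):

```cpp
// One translation unit, two GPU runtimes, selected at build time.
// PADDLE_WITH_HIP is the flag Paddle's build system defines for ROCm builds.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuError_t = hipError_t;   // ROCm runtime error type
#define gpuMemset hipMemset      // same (ptr, value, bytes) signature
#else
#include <cuda_runtime.h>
using gpuError_t = cudaError_t;  // CUDA equivalents
#define gpuMemset cudaMemset
#endif

// Everything below this point can be written once against the gpu* names.
```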
paddle/fluid/operators/activation_cudnn.cu.cc

```diff
@@ -14,7 +14,11 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_desc.h"
+#else
 #include "paddle/fluid/platform/cudnn_desc.h"
+#endif

 namespace paddle {
 namespace operators {
```
paddle/fluid/operators/activation_cudnn_op.cu.cc

```diff
@@ -14,7 +14,11 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_desc.h"
+#else
 #include "paddle/fluid/platform/cudnn_desc.h"
+#endif

 namespace paddle {
 namespace platform {
@@ -29,35 +33,71 @@ using platform::ActivationDescriptor;
 using platform::TensorDescriptor;
 using platform::CUDADeviceContext;

+#ifdef PADDLE_WITH_HIP
+#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
+#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
+#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
+#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
+#else
+#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
+#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
+#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
+#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
+#endif
+
 template <typename T>
 struct CudnnActivationFunctor {
   using ELEMENT_TYPE = T;
+#ifdef PADDLE_WITH_HIP
+  CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
+                         const miopenActivationMode_t& m)
+      : ctx_(ctx), coef_(c), mode_(m) {}
+#else
   CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
                          const cudnnActivationMode_t& m)
       : ctx_(ctx), coef_(c), mode_(m) {}
+#endif
   void operator()(const Tensor& x, Tensor* out) {
     ActivationDescriptor act_desc;
     act_desc.set(mode_, coef_);
     TensorDescriptor x_desc, out_desc;
     x_desc.set(x);
     out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
+        ctx_.cudnn_handle(), act_desc.desc(),
+        platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
+        platform::CudnnDataType<T>::kZero(), out_desc.desc(),
+        out->mutable_data<T>(ctx_.GetPlace())));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationForward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
         platform::CudnnDataType<T>::kZero(), out_desc.desc(),
         out->mutable_data<T>(ctx_.GetPlace())));
+#endif
   }
   const CUDADeviceContext& ctx_;
   const T coef_;
+#ifdef PADDLE_WITH_HIP
+  const miopenActivationMode_t mode_;
+#else
   const cudnnActivationMode_t mode_;
+#endif
 };

 template <typename T>
 struct CudnnActivationGradFunctor {
   using ELEMENT_TYPE = T;
+#ifdef PADDLE_WITH_HIP
+  CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
+                             const miopenActivationMode_t& m)
+      : ctx_(ctx), coef_(c), mode_(m) {}
+#else
   CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
                              const cudnnActivationMode_t& m)
       : ctx_(ctx), coef_(c), mode_(m) {}
+#endif
   void operator()(const Tensor& x, const Tensor& out, const Tensor dout,
                   Tensor* dx) {
     ActivationDescriptor act_desc;
@@ -67,27 +107,40 @@ struct CudnnActivationGradFunctor {
     out_desc.set(out);
     dout_desc.set(dout);
     dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationBackward(
+        ctx_.cudnn_handle(), act_desc.desc(),
+        platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
+        dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
+        platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
+        dx->mutable_data<T>(ctx_.GetPlace())));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationBackward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
         dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
         platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
         dx->mutable_data<T>(ctx_.GetPlace())));
+#endif
   }
   const CUDADeviceContext& ctx_;
   const T coef_;
+#ifdef PADDLE_WITH_HIP
+  const miopenActivationMode_t mode_;
+#else
   const cudnnActivationMode_t mode_;
+#endif
 };

 template <typename T>
 struct CudnnReluFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnReluFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
 };
 template <typename T>
 struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -95,13 +148,13 @@ struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnRelu6Functor : public CudnnActivationFunctor<T> {
   explicit CudnnRelu6Functor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {}
+      : CudnnActivationFunctor<T>(ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {}
 };
 template <typename T>
 struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
-  }
+      : CudnnActivationGradFunctor<T>(ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {
+  }

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -109,12 +162,12 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnSigmoidFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
 };
 template <typename T>
 struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -122,12 +175,12 @@ struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnTanhFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnTanhFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
 };
 template <typename T>
 struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -183,6 +236,14 @@ namespace ops = paddle::operators;
   __macro(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \
   __macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)

+#ifdef PADDLE_WITH_HIP
+#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
+  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
+                     ops::CudnnActivationKernel<ops::functor<float>>);    \
+  REGISTER_OP_KERNEL(                                                     \
+      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
+      ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
+#else
 #define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
   REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                      ops::CudnnActivationKernel<ops::functor<float>>,     \
@@ -191,5 +252,6 @@ namespace ops = paddle::operators;
       act_type##_grad, CUDNN, plat::CUDAPlace, \
       ops::CudnnActivationGradKernel<ops::grad_functor<float>>, \
       ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
+#endif

 FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);
```
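The `GPUDNN_ACTIVATION_*` macros above are what keep the eight functors backend-neutral: each maps to a `miopenActivation*` enum under HIP and to the matching `CUDNN_ACTIVATION_*` enum otherwise (note `miopenActivationLOGISTIC` standing in for sigmoid). Only the registration macro differs, since MIOpen has no double-precision kernels. A compile-only sketch of the mechanism, with stand-in enums rather than the real miopen.h/cudnn.h types:

```cpp
// Stand-in enums for illustration; the real ones come from the
// MIOpen / cuDNN headers.
#ifdef PADDLE_WITH_HIP
enum miopenActivationMode_t { miopenActivationRELU, miopenActivationTANH };
#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
#else
enum cudnnActivationMode_t { CUDNN_ACTIVATION_RELU, CUDNN_ACTIVATION_TANH };
#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
#endif

// Written once; resolves to whichever enum the build selected.
auto kReluMode = GPUDNN_ACTIVATION_RELU;
```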
paddle/fluid/operators/activation_op.cc

```diff
@@ -24,9 +24,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
 #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
 #include "paddle/fluid/platform/port.h"
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/cudnn_helper.h"
-#endif

 DECLARE_bool(use_mkldnn);
```
paddle/fluid/operators/affine_channel_op.cu

```diff
@@ -12,7 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
```
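The `namespace cub = hipcub;` alias is what lets the CUB-based code in this file compile unchanged under ROCm: hipCUB deliberately mirrors CUB's namespace layout, so every later `cub::` reference resolves to hipCUB. A compile-only sketch (assumes the respective toolchain headers are available):

```cpp
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;  // every cub:: use below now means hipcub::
#else
#include "cub/cub.cuh"
#endif

// cub::Sum is part of the API surface shared by CUB and hipCUB.
cub::Sum sum_op;
```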
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc

```diff
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#ifndef PADDLE_WITH_HIP
+// HIP not support cudnnSpatialTfGridGeneratorForward
+
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
@@ -121,3 +124,5 @@ REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace,
 REGISTER_OP_KERNEL(affine_grid_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNAffineGridGradOpKernel<float>,
                    paddle::operators::CUDNNAffineGridGradOpKernel<double>);
+
+#endif  // not PADDLE_WITH_HIP
```
paddle/fluid/operators/affine_grid_op.cc

```diff
@@ -21,6 +21,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif

 namespace paddle {
 namespace operators {
@@ -109,7 +112,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::CanCUDNNBeUsed(ctx)) {
       library = framework::LibraryType::kCUDNN;
     }
@@ -226,7 +229,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
```
paddle/fluid/operators/allclose_op.cu

```diff
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <cuda_runtime.h>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/allclose_op.h"
@@ -67,7 +66,11 @@ struct AllcloseFunctor<platform::CUDADeviceContext, T> {
     int block = 1024;
     int grid = (block - 1 + num) / block;
     grid = (grid > block) ? block : grid;
+#ifdef PADDLE_WITH_HIP
+    hipMemset(out_data, true, sizeof(bool));
+#else
     cudaMemset(out_data, true, sizeof(bool));
+#endif
     AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
         in_data, other_data, rtol, atol, equal_nan, num, out_data);
   }
```
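The memset pair above works for initializing a `bool` flag because both `cudaMemset` and `hipMemset` fill bytes, and `sizeof(bool)` is 1 on these platforms, so passing `true` writes the single byte 0x01. A runnable host-side analogue of the same idea:

```cpp
// Host-side analogue of the device memset in the hunk above.
#include <cstdio>
#include <cstring>

int main() {
  bool flag = false;
  std::memset(&flag, true, sizeof(bool));  // fills one byte with 0x01
  std::printf("%d\n", flag);               // prints 1
  return 0;
}
```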
paddle/fluid/operators/arg_min_max_op_base.cu.h

```diff
@@ -14,9 +14,15 @@ limitations under the License. */

 #pragma once

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)

-#include <cub/cub.cuh>
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include <limits>
 #include <string>
 #include <typeinfo>
```
paddle/fluid/operators/argsort_op.cu

```diff
@@ -16,13 +16,28 @@ limitations under the License. */
 #include <thrust/execution_policy.h>
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/argsort_op.h"
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"

+#ifdef __HIPCC__
+namespace rocprim {
+namespace detail {
+template <>
+struct radix_key_codec_base<paddle::platform::float16>
+    : radix_key_codec_integral<paddle::platform::float16, uint16_t> {};
+}  // namespace detail
+}  // namespace rocprim
+#else
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
@@ -30,6 +45,7 @@ struct NumericTraits<paddle::platform::float16>
     : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                  paddle::platform::float16> {};
 }  // namespace cub
+#endif

 namespace paddle {
 namespace operators {
@@ -139,7 +155,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
       cub::CountingInputIterator<IndType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));

-  cudaError_t err;
+  gpuError_t err;
   if (descending) {
     err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
         nullptr, temp_storage_bytes, inp, sorted_out_ptr,
```
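Both branches of the new guard solve the same problem: segmented radix sort needs bitwise key traits for `paddle::platform::float16`, which neither CUB nor rocPRIM knows natively. CUB's hook is the `NumericTraits` specialization kept in the `#else` branch; rocPRIM's equivalent is `rocprim::detail::radix_key_codec_base`, here specialized to encode the 16-bit float through a `uint16_t` key. The `cudaError_t` to `gpuError_t` rename further down makes the error-handling code backend-neutral as well.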
paddle/fluid/operators/batch_fc_op.cu

```diff
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include <cublas.h>
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/batch_fc_op.h"
@@ -42,7 +41,7 @@ __global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
 }

 template <typename T>
-void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num,
+void add_bias(gpuStream_t stream, T* data, int slot_pairs_num, int ins_num,
               int out_dim, const T* bias) {
   add_bias_kernel<<<GET_BLOCKS(slot_pairs_num * ins_num * out_dim),
                     CUDA_NUM_THREADS, 0, stream>>>(data, slot_pairs_num,
@@ -65,7 +64,7 @@ __global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
 }

 template <typename T>
-void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num,
+void add_bias_grad(gpuStream_t stream, const T* dout_data, int slot_pairs_num,
                    int ins_num, int out_dim, T* db_data) {
   add_bias_grad_kernel<<<GET_BLOCKS(slot_pairs_num * out_dim),
                          CUDA_NUM_THREADS, 0, stream>>>(dout_data,
                                                         slot_pairs_num, ins_num,
```
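`gpuStream_t` here (like `gpuError_t` in argsort_op.cu) is a backend-neutral typedef supplied by Paddle's platform layer; the defining header is not part of this diff, but the shape of the definition is presumably along these lines:

```cpp
// Assumed sketch of the unified typedefs used by the code above; the
// exact header and spelling in Paddle are not shown in this commit.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t;
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#endif
```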
paddle/fluid/operators/batch_norm_op.cu

```diff
@@ -16,12 +16,17 @@ limitations under the License. */
 #include <cfloat>
 #include <string>
 #include <vector>
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/norm_utils.cu.h"
-#include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/float16.h"

 DECLARE_bool(cudnn_batchnorm_spatial_persistent);
@@ -73,6 +78,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);

     auto dtype = platform::CudnnDataType<T>::type;
+
+#ifdef PADDLE_WITH_HIP
+    // HIP do not support compute format of NHWC
+    auto compute_format = DataLayout::kNCHW;
+#else
     const bool fast_nhwc_batch_norm =
         test_mode ||
         (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);
@@ -81,6 +91,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
         fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
             ? DataLayout::kNHWC
             : DataLayout::kNCHW;
+#endif

     Tensor transformed_x(x->type());
     Tensor transformed_y(y->type());
@@ -98,7 +109,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       transformed_y.ShareDataWith(*y);
     }

-    // ------------------- cudnn descriptors ---------------------
+// ------------------- cudnn descriptors ---------------------
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t data_desc_;
+    miopenTensorDescriptor_t bn_param_desc_;
+    miopenBatchNormMode_t mode_;
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+#else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
     cudnnBatchNormMode_t mode_;
@@ -107,6 +128,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
         platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+#endif

     if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
       LOG(ERROR) << "Provided epsilon is smaller than "
@@ -114,7 +136,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
                  << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 1)
+
+#ifdef PADDLE_WITH_HIP
+    mode_ = miopenBNSpatial;
+#elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -134,6 +159,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       dims = {N, C, H, W, D};
       strides = {H * W * D * C, 1, W * D * C, D * C, C};
     }

+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+        data_desc_, CudnnDataType<T>::type,
+        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int*>(dims.data()),
+        const_cast<int*>(strides.data())));
+    // Note: PERSISTENT not implemented for inference
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDeriveBNTensorDescriptor(
+            bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
@@ -142,6 +178,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDeriveBNTensorDescriptor(
             bn_param_desc_, data_desc_,
             test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
+#endif

     const auto* scale = ctx.Input<Tensor>("Scale");
     const auto* bias = ctx.Input<Tensor>("Bias");
@@ -188,6 +225,30 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
                 "variance is [%d], the dimensions of variance is [%s].",
                 C, est_var->dims()[0], est_var->dims()));

+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenBatchNormalizationForwardInference(
+              handle, miopenBNSpatial,
+              const_cast<void*>(
+                  static_cast<const void*>(CudnnDataType<T>::kOne())),
+              const_cast<void*>(
+                  static_cast<const void*>(CudnnDataType<T>::kZero())),
+              data_desc_,
+              static_cast<const void*>(transformed_x.template data<T>()),
+              data_desc_,
+              static_cast<void*>(
+                  transformed_y.template mutable_data<T>(ctx.GetPlace())),
+              bn_param_desc_,
+              const_cast<void*>(static_cast<const void*>(
+                  scale->template data<BatchNormParamType<T>>())),
+              const_cast<void*>(static_cast<const void*>(
+                  bias->template data<BatchNormParamType<T>>())),
+              const_cast<void*>(static_cast<const void*>(
+                  est_mean->template data<BatchNormParamType<T>>())),
+              const_cast<void*>(static_cast<const void*>(
+                  est_var->template data<BatchNormParamType<T>>())),
+              epsilon));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnBatchNormalizationForwardInference(
               handle,
@@ -200,6 +261,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
               bias->template data<BatchNormParamType<T>>(),
               est_mean->template data<BatchNormParamType<T>>(),
               est_var->template data<BatchNormParamType<T>>(), epsilon));
+#endif
     } else {
       // if MomentumTensor is set, use MomentumTensor value, momentum
       // is only used in this training branch
@@ -302,6 +364,36 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
               reserve_space_size));
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
         if (!called) {
+#ifdef PADDLE_WITH_HIP
+          PADDLE_ENFORCE_CUDA_SUCCESS(
+              platform::dynload::miopenBatchNormalizationForwardTraining(
+                  handle, mode_, const_cast<void*>(static_cast<const void*>(
+                                     CudnnDataType<T>::kOne())),
+                  const_cast<void*>(
+                      static_cast<const void*>(CudnnDataType<T>::kZero())),
+                  data_desc_,
+                  static_cast<const void*>(transformed_x.template data<T>()),
+                  data_desc_,
+                  static_cast<void*>(
+                      transformed_y.template mutable_data<T>(ctx.GetPlace())),
+                  bn_param_desc_,
+                  const_cast<void*>(static_cast<const void*>(
+                      scale->template data<BatchNormParamType<T>>())),
+                  const_cast<void*>(static_cast<const void*>(
+                      bias->template data<BatchNormParamType<T>>())),
+                  this_factor,
+                  static_cast<void*>(
+                      mean_out->template mutable_data<BatchNormParamType<T>>(
+                          ctx.GetPlace())),
+                  static_cast<void*>(variance_out->template mutable_data<
+                                     BatchNormParamType<T>>(ctx.GetPlace())),
+                  epsilon,
+                  static_cast<void*>(
+                      saved_mean->template mutable_data<BatchNormParamType<T>>(
+                          ctx.GetPlace())),
+                  static_cast<void*>(saved_variance->template mutable_data<
+                                     BatchNormParamType<T>>(ctx.GetPlace()))));
+#else
           PADDLE_ENFORCE_CUDA_SUCCESS(
               platform::dynload::cudnnBatchNormalizationForwardTraining(
                   handle, mode_, CudnnDataType<T>::kOne(),
@@ -319,6 +411,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
                       ctx.GetPlace()),
                   saved_variance->template mutable_data<BatchNormParamType<T>>(
                       ctx.GetPlace())));
+#endif
         }
       }
     }
@@ -329,11 +422,19 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
           ctx, &transformed_y, y);
     }
+#ifdef PADDLE_WITH_HIP
+    // clean when exit.
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+#else
     // clean when exit.
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#endif
   }
 };
@@ -416,7 +517,7 @@ class InplaceHelper {
                   const BatchNormParamType<T>* mean,
                   const BatchNormParamType<T>* variance, double epsilon, int C,
                   int M, const int num, const T* y, int grid2, const int block,
-                  const cudaStream_t &stream) {
+                  const gpuStream_t &stream) {
     PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
                                 "X and Y should be inplaced in inplace mode"));
     KeBNRestoreData<<<grid2, block, 0, stream>>>(
@@ -566,6 +667,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     auto dtype = platform::CudnnDataType<T>::type;
     const auto* reserve_space = ctx.Input<Tensor>("ReserveSpace");
+#ifdef PADDLE_WITH_HIP
+    // HIP do not support compute format of NHWC
+    auto compute_format = DataLayout::kNCHW;
+#else
     const bool fast_nhwc_batch_norm =
         dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
         reserve_space != nullptr;
@@ -573,6 +678,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
         fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
             ? DataLayout::kNHWC
             : DataLayout::kNCHW;
+#endif
     Tensor transformed_x(x->type());
     Tensor transformed_d_y(d_y->type());
@@ -626,7 +732,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       return;
     }

-    // ------------------- cudnn descriptors ---------------------
+// ------------------- cudnn descriptors ---------------------
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t data_desc_;
+    miopenTensorDescriptor_t bn_param_desc_;
+    miopenBatchNormMode_t mode_;
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+#else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
     cudnnBatchNormMode_t mode_;
@@ -635,13 +751,16 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
         platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+#endif
     if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
       LOG(ERROR) << "Provided epsilon is smaller than "
                  << "CUDNN_BN_MIN_EPSILON. Setting it to "
                  << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 1)
+
+#ifdef PADDLE_WITH_HIP
+    mode_ = miopenBNSpatial;
+#elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -651,12 +770,22 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     mode_ = CUDNN_BATCHNORM_SPATIAL;
 #endif  // CUDNN_VERSION_MIN(7, 0, 1)

+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+        data_desc_, CudnnDataType<T>::type,
+        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int*>(dims.data()),
+        const_cast<int*>(strides.data())));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
+                                                          data_desc_, mode_));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
                                                          data_desc_, mode_));
+#endif

     const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto* saved_var = ctx.Input<Tensor>("SavedVariance");
@@ -741,6 +870,22 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                 /*reserveSpaceSizeInBytes=*/reserve_space_size));
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
         if (!called) {
+#ifdef PADDLE_WITH_HIP
+          PADDLE_ENFORCE_CUDA_SUCCESS(
+              platform::dynload::miopenBatchNormalizationBackward(
+                  dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
+                  CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
+                  CudnnDataType<T>::kZero(), data_desc_,
+                  transformed_x.template data<T>(), data_desc_,
+                  transformed_d_y.template data<T>(), data_desc_,
+                  transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
+                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+                  d_scale->template mutable_data<BatchNormParamType<T>>(
+                      ctx.GetPlace()),
+                  d_bias->template mutable_data<BatchNormParamType<T>>(
+                      ctx.GetPlace()),
+                  epsilon, saved_mean_data, saved_var_data));
+#else
           PADDLE_ENFORCE_CUDA_SUCCESS(
               platform::dynload::cudnnBatchNormalizationBackward(
                   dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
@@ -755,6 +900,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                   d_bias->template mutable_data<BatchNormParamType<T>>(
                       ctx.GetPlace()),
                   epsilon, saved_mean_data, saved_var_data));
+#endif
         }

         if (data_layout == DataLayout::kNHWC &&
@@ -784,11 +930,19 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
         }
       }

+#ifdef PADDLE_WITH_HIP
+      // clean when exit.
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+#else
       // clean when exit.
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#endif
     } else {
       const auto* running_mean = ctx.Input<Tensor>("Mean");
       const auto* running_var = ctx.Input<Tensor>("Variance");
@@ -886,6 +1040,18 @@ class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
+#ifdef PADDLE_WITH_HIP
+// MIOPEN do not support double
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm_grad_grad,
+    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>);
+#else
 REGISTER_OP_CUDA_KERNEL(
     batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
     ops::BatchNormKernel<plat::CUDADeviceContext, double>,
@@ -898,3 +1064,4 @@ REGISTER_OP_CUDA_KERNEL(
     batch_norm_grad_grad,
     ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
     ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
+#endif
```
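Two details of the HIP branches above are worth noting. First, the `const_cast<void*>(static_cast<const void*>(...))` wrappers appear only in the MIOpen calls because MIOpen's batch-norm entry points take non-const `void*` parameters where cuDNN takes `const void*`, so the const data pointers must be laundered at the call site. Second, the registration split at the bottom encodes MIOpen's missing double support: the HIP build registers only float and float16 kernels (and a float-only double-grad kernel), while the CUDA branch keeps the double instantiations.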
paddle/fluid/operators/bce_loss_op.cu

```diff
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include <algorithm>
-#include "cub/cub.cuh"
 #include "paddle/fluid/operators/bce_loss_op.h"
 #include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
```
paddle/fluid/operators/math/sequence_padding_test.cc

```diff
@@ -105,7 +105,7 @@ TEST(Seq2BatchPadding, CPU) {
       128);
 }

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(SequencePadding, CUDA) {
   auto place = paddle::platform::CUDAPlace(0);
   auto* context = static_cast<paddle::platform::CUDADeviceContext*>(
```
paddle/fluid/operators/math/sequence_pooling_test.cc

```diff
@@ -123,7 +123,7 @@ TEST(SequencePoolingGrad, CPU_SUM) {
       lod2, 128);
 }

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(SequencePoolingGrad, CUDA_SUM) {
   auto place = paddle::platform::CUDAPlace(0);
   auto* context = static_cast<paddle::platform::CUDADeviceContext*>(
```
paddle/fluid/operators/math/sequence_scale.cu

```diff
@@ -44,10 +44,18 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
     T* seq_data = seq->mutable_data<T>(context.GetPlace());

+#ifdef PADDLE_WITH_HIP
+    hipLaunchKernelGGL(
+        HIP_KERNEL_NAME(SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS>),
+        dim3(num_seq), dim3(PADDLE_CUDA_NUM_THREADS), 0, context.stream(),
+        seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
+        scales, seq_width);
+#else
     SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
         num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
         seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
         scales, seq_width);
+#endif
   }
 };
```
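`hipLaunchKernelGGL` is HIP's portable launch macro: its first five arguments (kernel, grid, block, shared-memory bytes, stream) replace the `<<<...>>>` chevron syntax, and the kernel arguments follow. A minimal HIP-only sketch (assumes a ROCm toolchain):

```cpp
#include <hip/hip_runtime.h>

__global__ void scale(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void launch_scale(float* d_data, int n, hipStream_t stream) {
  // Equivalent to: scale<<<(n + 255) / 256, 256, 0, stream>>>(d_data, 2.f, n)
  hipLaunchKernelGGL(scale, dim3((n + 255) / 256), dim3(256), 0, stream,
                     d_data, 2.0f, n);
}
```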
paddle/fluid/operators/math/softmax.cu

```diff
@@ -16,7 +16,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#else
 #include "paddle/fluid/platform/cudnn_helper.h"
+#endif

 namespace paddle {
 namespace operators {
@@ -45,6 +49,16 @@ void SoftmaxCUDNNFunctor<T>::operator()(
   if (cudnn_tensor_dims.size() <= 2) {
     cudnn_tensor_dims.resize(4, 1);
   }
+#ifdef PADDLE_WITH_HIP
+  miopenTensorDescriptor_t cudnn_x_desc =
+      xDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  miopenTensorDescriptor_t cudnn_y_desc =
+      xDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward(
+      context.cudnn_handle(), CudnnDataType<T>::kOne(), cudnn_x_desc,
+      X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
+      Y->mutable_data<T>(context.GetPlace())));
+#else
   cudnnTensorDescriptor_t cudnn_x_desc =
       xDesc.descriptor<T>(layout, cudnn_tensor_dims);
   cudnnTensorDescriptor_t cudnn_y_desc =
@@ -54,6 +68,7 @@ void SoftmaxCUDNNFunctor<T>::operator()(
       CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc,
       X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
       Y->mutable_data<T>(context.GetPlace())));
+#endif
 }

 template <typename T>
@@ -74,6 +89,19 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
   if (cudnn_tensor_dims.size() <= 2) {
     cudnn_tensor_dims.resize(4, 1);
   }
+#ifdef PADDLE_WITH_HIP
+  miopenTensorDescriptor_t cudnn_y_desc =
+      yDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  miopenTensorDescriptor_t cudnn_xgrad_desc =
+      dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  miopenTensorDescriptor_t cudnn_ygrad_desc =
+      dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward(
+      context.cudnn_handle(), CudnnDataType<T>::kOne(), cudnn_y_desc,
+      Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
+      CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
+      XGrad->mutable_data<T>(context.GetPlace())));
+#else
   cudnnTensorDescriptor_t cudnn_y_desc =
       yDesc.descriptor<T>(layout, cudnn_tensor_dims);
   cudnnTensorDescriptor_t cudnn_xgrad_desc =
@@ -86,15 +114,20 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
       Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
       CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
       XGrad->mutable_data<T>(context.GetPlace())));
+#endif
 }

 template class SoftmaxCUDNNFunctor<platform::float16>;
 template class SoftmaxCUDNNFunctor<float>;
-template class SoftmaxCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<float>;
-template class SoftmaxGradCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<platform::float16>;
+// MIOPEN do not support double
+#ifndef PADDLE_WITH_HIP
+template class SoftmaxCUDNNFunctor<double>;
+template class SoftmaxGradCUDNNFunctor<double>;
+#endif

 template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
                               false>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
```
paddle/fluid/operators/math/softmax.h

```diff
@@ -35,7 +35,7 @@ class SoftmaxGradFunctor {
                   framework::Tensor* x_grad);
 };

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template <typename T>
 class SoftmaxCUDNNFunctor {
  public:
```
paddle/fluid/operators/pool_op.h

```diff
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/pooling.h"
-#ifdef __NVCC__
+#if defined(__HIPCC__) || defined(__NVCC__)
 #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 #endif
```
python/paddle/fluid/tests/unittests/op_test.py

```diff
@@ -278,6 +278,9 @@ class OpTest(unittest.TestCase):
         def is_mkldnn_op_test():
             return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True

+        def is_rocm_op_test():
+            return core.is_compiled_with_rocm()
+
         if not hasattr(cls, "op_type"):
             raise AssertionError(
                 "This test do not have op_type in class attrs, "
@@ -298,7 +301,8 @@ class OpTest(unittest.TestCase):
                     and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
                     and not hasattr(cls, 'exist_fp64_check_grad') \
                     and not is_xpu_op_test() \
-                    and not is_mkldnn_op_test():
+                    and not is_mkldnn_op_test() \
+                    and not is_rocm_op_test():
                 raise AssertionError(
                     "This test of %s op needs check_grad with fp64 precision." %
                     cls.op_type)
```
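With `is_rocm_op_test()` added to the guard chain, the mandatory fp64 `check_grad` is skipped whenever the wheel was built for ROCm; `core.is_compiled_with_rocm()` is Paddle's runtime probe for that build flavor. This lines up with the MIOpen gaps seen above (no double-precision batch norm or softmax), which leave many operators without double-precision GPU kernels on ROCm.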