PaddlePaddle / Paddle

Commit 6ba0507d (unverified)

add fused dropout add (#51752)

Authored on Mar 22, 2023 by ShenLiang; committed via GitHub on Mar 22, 2023.
Parent: a10718e8

Showing 14 changed files with 982 additions and 1 deletion (+982 / -1).
Changed files:

paddle/phi/api/yaml/backward.yaml                                    +10  -0
paddle/phi/api/yaml/ops.yaml                                         +10  -0
paddle/phi/infermeta/binary.cc                                       +16  -0
paddle/phi/infermeta/binary.h                                        +10  -0
paddle/phi/kernels/CMakeLists.txt                                     +2  -1
paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h            +33  -0
paddle/phi/kernels/fusion/fused_dropout_add_kernel.h                 +55  -0
paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu      +241  -0
paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu           +218  -0
python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py    +189  -0
python/paddle/incubate/nn/__init__.py                                 +2  -0
python/paddle/incubate/nn/functional/__init__.py                      +2  -0
python/paddle/incubate/nn/functional/fused_dropout_add.py           +116  -0
python/paddle/incubate/nn/layer/fused_dropout_add.py                 +78  -0
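In short, the new fused_dropout_add operator computes dropout(x) + y in a single kernel launch and keeps only a two-element (seed, offset) tensor for the backward pass instead of a full dropout mask. A minimal usage sketch of the intended equivalence, patterned on the reference helper and the layer test in the new unit test file (assumes a CUDA build of Paddle; shapes, seeds, and the tolerance are illustrative):

# Minimal sketch, assuming a CUDA build of Paddle (the op is GPU-only in this
# commit). Mirrors `paddle_dropout_add` / `test_fused_dropout_add_layer` from
# the new unit test.
import numpy as np
import paddle
from paddle.incubate.nn.functional import fused_dropout_add

x = paddle.randn([4, 10], dtype='float32')
y = paddle.randn([4, 10], dtype='float32')

paddle.seed(2048)
fused = fused_dropout_add(x, y, p=0.5)                   # one fused kernel

paddle.seed(2048)
reference = paddle.nn.functional.dropout(x, p=0.5) + y   # two separate ops

# With the same seed both paths should draw the same dropout pattern,
# which is what the new test asserts with np.testing.assert_allclose.
np.testing.assert_allclose(fused.numpy(), reference.numpy(), rtol=1e-05)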
paddle/phi/api/yaml/backward.yaml

@@ -605,6 +605,16 @@
   kernel :
     func : frame_grad

+- backward_op : fused_dropout_add_grad
+  forward : fused_dropout_add (Tensor x, Tensor y, Scalar p, bool is_test, str mode, int seed, bool fix_seed) -> Tensor(out), Tensor(seed_offset)
+  args : (Tensor seed_offset, Tensor out_grad, Scalar p, bool is_test, str mode, bool fix_seed)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [out_grad, out_grad]
+  kernel :
+    func : fused_dropout_add_grad
+
 - backward_op : gather_nd_grad
   forward : gather_nd (Tensor x, Tensor index) -> Tensor(out)
   args : (Tensor x, Tensor index, Tensor out_grad)
paddle/phi/api/yaml/ops.yaml

@@ -604,6 +604,16 @@
     func : frame
   backward : frame_grad

+- op : fused_dropout_add
+  args : (Tensor x, Tensor y, Scalar p, bool is_test, str mode, int seed, bool fix_seed)
+  output : Tensor(out), Tensor(seed_offset)
+  infer_meta :
+    func : FusedDropoutAddInferMeta
+  kernel :
+    func : fused_dropout_add
+    data_type : x
+  backward : fused_dropout_add_grad
+
 - op : fused_linear_param_grad_add
   args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
   output : Tensor(dweight_out), Tensor(dbias_out)
paddle/phi/infermeta/binary.cc

@@ -1287,6 +1287,22 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }

+void FusedDropoutAddInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              const Scalar& p,
+                              bool is_test,
+                              const std::string& mode,
+                              int seed,
+                              bool fix_seed,
+                              MetaTensor* out,
+                              MetaTensor* seed_offset) {
+  out->share_meta(x);
+  if (seed_offset) {
+    seed_offset->set_dims({2});
+    seed_offset->set_dtype(DataType::INT64);
+  }
+}
+
 // Used in FusedMatmulInferMeta
 static std::vector<int64_t> GetInputShape(phi::DDim dim,
                                           std::vector<int> shape,
paddle/phi/infermeta/binary.h

@@ -222,6 +222,16 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x,
                                  int dim2,
                                  MetaTensor* out);

+void FusedDropoutAddInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              const Scalar& p,
+                              bool is_test,
+                              const std::string& mode,
+                              int seed,
+                              bool fix_seed,
+                              MetaTensor* out,
+                              MetaTensor* seed_offset);
+
 void FusedMatmulInferMeta(const MetaTensor& x,
                           const MetaTensor& y,
                           const MetaTensor& residual_data,
paddle/phi/kernels/CMakeLists.txt

@@ -81,11 +81,12 @@ set(COMMON_KERNEL_DEPS
     utf8proc
     gather_scatter_functor)
-set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
 if(WITH_FLASHATTN)
   set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_dynload_flashattn)
 endif()
+
+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
 if(WITH_NCCL OR WITH_RCCL)
   set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl
                          nccl_comm_context)
paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h (new file, mode 100644, +33)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void FusedDropoutAddGradKernel(const Context& dev_ctx,
                               const DenseTensor& seed_offset,
                               const DenseTensor& out_grad,
                               const Scalar& p,
                               bool is_test,
                               const std::string& mode,
                               bool fix_seed,
                               DenseTensor* x_grad,
                               DenseTensor* y_grad);

}  // namespace phi
paddle/phi/kernels/fusion/fused_dropout_add_kernel.h (new file, mode 100644, +55)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

namespace phi {

template <typename T, typename Context>
void FusedDropoutAddKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& y,
                           const Scalar& p,
                           bool is_test,
                           const std::string& mode,
                           int seed,
                           bool fix_seed,
                           DenseTensor* out,
                           DenseTensor* seed_offset);

template <typename Context>
static inline std::vector<size_t> GetRandomCudaProp(int numel,
                                                    const Context& dev_ctx) {
  constexpr int kVecSize = funcs::uniform_distribution<float>::kReturnsCount;
  auto gpu_config =
      backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, kVecSize);
  size_t grid_size = gpu_config.GetGridSize();
  size_t block_size = gpu_config.GetBlockSize();
  int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
  const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
  size_t max_grid_size =
      prop.maxThreadsPerMultiProcessor * prop.multiProcessorCount / block_size;
  grid_size = std::min(grid_size, max_grid_size);
  auto offset =
      ((numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
  size_t main_offset =
      numel / (block_size * kVecSize) * (block_size * kVecSize);
  return {grid_size, block_size, offset, main_offset};
}

}  // namespace phi
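GetRandomCudaProp packs four launch parameters into a vector: the occupancy-capped grid size, the block size, the per-thread Philox offset, and main_offset, the largest multiple of block_size * kVecSize not exceeding numel (everything past main_offset is handled by the remainder branch of the vectorized kernels). A small sketch of the arithmetic with assumed example numbers; the real values come from GetGpuLaunchConfig1D and the queried device properties:

# Pure-Python sketch of the offset / main_offset arithmetic above.
# numel, grid_size, block_size and kVecSize are assumed example values,
# not numbers queried from a real device.
numel = 10_000
grid_size, block_size, kVecSize = 20, 128, 4

# Philox offset: random numbers consumed per thread, rounded up to kVecSize.
offset = ((numel - 1) // (grid_size * block_size * kVecSize) + 1) * kVecSize

# Largest multiple of block_size * kVecSize that fits in numel; the tail is
# processed by the remainder branch of the vectorized kernels.
main_offset = numel // (block_size * kVecSize) * (block_size * kVecSize)

print(offset, main_offset)  # -> 4 9728 for these example values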
paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu (new file, mode 100644, +241)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h"
#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}

namespace phi {

template <typename T, typename MT>
__global__ void FuseScaleAddGrad(const T* grad,
                                 T* x,
                                 T* y,
                                 const MT factor,
                                 const int64_t limit,
                                 bool upscale_in_train) {
  CUDA_KERNEL_LOOP(i, limit) {
    y[i] = grad[i];
    x[i] = upscale_in_train
               ? grad[i]
               : static_cast<T>(static_cast<MT>(grad[i]) * factor);
  }
}

template <typename T>
__global__ void FuseScaleAddGradRateZero(const T* grad,
                                         T* src,
                                         T* res,
                                         const int64_t limit) {
  CUDA_KERNEL_LOOP(i, limit) {
    res[i] = grad[i];
    src[i] = 0;
  }
}

template <typename T1, typename T2 = T1, typename OutT = T1>
struct NoMaskBwFunctor {
  const float retain_prob_;
  using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
  MT factor_;
  HOSTDEVICE inline NoMaskBwFunctor(const float retain_prob)
      : retain_prob_(retain_prob) {
    factor_ = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline NoMaskBwFunctor(const float retain_prob, const MT factor)
      : retain_prob_(retain_prob), factor_(factor) {}

  HOSTDEVICE inline void operator()(OutT* dst,
                                    const T1* src_val,
                                    const T2* rand,
                                    int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      dst[i + kCount] = src_val[i];
      dst[i] = rand[i] < retain_prob_
                   ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor_)
                   : static_cast<T1>(0);
    }
  }
};

template <typename T, typename Functor>
__global__ void VectorizedDropoutBackward(const size_t n,
                                          uint64_t seed,
                                          T* src,
                                          T* res,
                                          const T* dst,
                                          uint64_t increment,
                                          size_t main_offset,
                                          Functor functor) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  float rands[kCount];
  T src_res[kCount * 2];
  T res_grad[kCount];

  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;

  int deal_size = BLOCK_NUM_X * kCount;
  size_t fix = idx * kCount;

  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&src_res[0], dst + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // x_grad
    kps::OperatorTernary<T, float, T, Functor>(
        &src_res[0], &src_res[0], &rands[0], functor, kCount);
    kps::WriteData<T, kCount, 1, false>(src + fix, &src_res[0], deal_size);
    // res
    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
        &res_grad[0], &src_res[kCount], Cast());
    kps::WriteData<T, kCount, 1, false>(res + fix, &res_grad[0], deal_size);
    if (fix > idx * kCount + 1) {
      __syncthreads();
    }
  }

  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, true>(&src_res[0], dst + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // x_grad
    kps::OperatorTernary<T, float, T, Functor>(
        &src_res[0], &src_res[0], &rands[0], functor, kCount);
    kps::WriteData<T, kCount, 1, true>(src + fix, &src_res[0], remainder);
    // res
    kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
        &res_grad[0], &src_res[kCount], Cast());
    kps::WriteData<T, kCount, 1, true>(res + fix, &res_grad[0], remainder);
    __syncthreads();
  }
}

template <typename T, typename Context>
void FusedDropoutAddGradKernel(const Context& dev_ctx,
                               const DenseTensor& seed_offset,
                               const DenseTensor& out_grad,
                               const Scalar& p,
                               bool is_test,
                               const std::string& mode,
                               bool fix_seed,
                               DenseTensor* x_grad,
                               DenseTensor* y_grad) {
  int64_t numel = out_grad.numel();
  auto stream = dev_ctx.stream();
  float dropout_rate = p.to<float>();
  bool upscale_in_train = (mode == "upscale_in_train");

  const auto* seed_offset_data = seed_offset.data<int64_t>();
  const uint64_t seed_data = static_cast<uint64_t>(seed_offset_data[0]);
  const uint64_t increment = static_cast<uint64_t>(seed_offset_data[1]);

  auto* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
  auto* y_grad_data = dev_ctx.template Alloc<T>(y_grad);
  const auto* out_grad_data = out_grad.data<T>();

  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  int blocks = NumBlocks(numel);
  int threads = kNumCUDAThreads;

  if (is_test) {
    MT factor = static_cast<MT>(1.0f - dropout_rate);
    FuseScaleAddGrad<T, MT><<<blocks, threads, 0, stream>>>(out_grad_data,
                                                            x_grad_data,
                                                            y_grad_data,
                                                            factor,
                                                            numel,
                                                            upscale_in_train);
  } else {
    if (upscale_in_train && dropout_rate == 1.0f) {
      FuseScaleAddGradRateZero<T><<<blocks, threads, 0, stream>>>(
          out_grad_data, x_grad_data, y_grad_data, numel);
      return;
    }
    auto random_prop = GetRandomCudaProp(numel, dev_ctx);
    size_t grid_size = random_prop[0];
    size_t block_size = random_prop[1];
    size_t offset = random_prop[2];
    size_t main_offset = random_prop[3];
    auto functor = upscale_in_train
                       ? NoMaskBwFunctor<T, float>(1.0f - dropout_rate)
                       : NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);

#define PD_DROPOUT_KERNEL_NAME \
  VectorizedDropoutBackward<T, NoMaskBwFunctor<T, float>>
    PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed,
                                       PD_DROPOUT_KERNEL_NAME,
                                       grid_size,
                                       block_size,
                                       0,
                                       stream,
                                       offset,
                                       KERNEL_PARAMS.As<uint64_t>(1),
                                       KERNEL_PARAMS.As<uint64_t>(5),
                                       numel,
                                       seed_data,  // need save
                                       x_grad_data,
                                       y_grad_data,
                                       out_grad_data,  // grad
                                       increment,  // need save
                                       main_offset,
                                       functor);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(fused_dropout_add_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedDropoutAddGradKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
}
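Because the forward result is out = dropout(x) + y, the backward pass splits the incoming gradient in two: y_grad is out_grad unchanged, and x_grad is out_grad gated by the regenerated dropout pattern (rescaled by 1/(1-p) in upscale_in_train mode). A NumPy sketch of that rule under an assumed, known mask; the CUDA kernel above never stores the mask and instead re-samples it from the saved (seed, increment) pair:

# NumPy sketch of the gradient rule implemented by NoMaskBwFunctor
# (upscale_in_train mode). The mask is assumed here for illustration only;
# the kernel regenerates it from the saved seed/offset via Philox.
import numpy as np

p = 0.5                                    # dropout probability
out_grad = np.array([1.0, 2.0, 3.0, 4.0])
mask = np.array([1.0, 0.0, 1.0, 1.0])      # 1 = kept, 0 = dropped

y_grad = out_grad                          # the add branch passes the gradient through
x_grad = out_grad * mask / (1.0 - p)       # kept lanes rescaled, dropped lanes zero

print(x_grad)  # -> [2. 0. 6. 8.]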
paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu (new file, mode 100644, +218)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"

namespace phi {

template <typename T1, typename T2 = T1, typename OutT = T1>
struct NoMaskFwFunctor {
  const float retain_prob_;
  const bool is_upscale_in_train_;
  using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
  MT factor;
  HOSTDEVICE inline NoMaskFwFunctor(const float retain_prob,
                                    const bool is_upscale_in_train)
      : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline void operator()(OutT* dst,
                                    const T1* src_val,
                                    const T2* rand,
                                    int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T1>(src_val[i]);
        dst[i] += src_val[i + kCount];
      } else {
        dst[i] = src_val[i + kCount];
      }
    }
  }
};

template <typename T>
struct ScaleAddFuctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  explicit ScaleAddFuctor(const MT factor, bool upscale_in_train)
      : factor_(factor), upscale_in_train_(upscale_in_train) {}

  __device__ __forceinline__ T operator()(const T src, const T res) const {
    return upscale_in_train_
               ? src + res
               : static_cast<T>(static_cast<MT>(src) * factor_) + res;
  }

 private:
  MT factor_;
  bool upscale_in_train_;
};

template <typename T, typename Functor>
__global__ void VectorizedDropoutForward(const size_t n,
                                         uint64_t seed,
                                         const T* src,
                                         const T* res,
                                         T* dst,
                                         uint64_t increment,
                                         size_t main_offset,
                                         Functor functor) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_res[kCount * 2];
  float rands[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;

  int deal_size = BLOCK_NUM_X * kCount;
  size_t fix = idx * kCount;

  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_res[0], src + fix, deal_size);
    kps::ReadData<T, kCount, 1, false>(&dst_res[kCount], res + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, Functor>(
        &dst_res[0], &dst_res[0], &rands[0], functor, kCount);
    kps::WriteData<T, kCount, 1, false>(dst + fix, &dst_res[0], deal_size);
    if (fix > idx * kCount + 1) {
      __syncthreads();
    }
  }

  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, true>(&dst_res[0], src + fix, remainder);
    kps::ReadData<T, kCount, 1, true>(&dst_res[kCount], res + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, Functor>(
        &dst_res[0], &dst_res[0], &rands[0], functor, kCount);
    kps::WriteData<T, kCount, 1, true>(dst + fix, &dst_res[0], remainder);
    __syncthreads();
  }
}

template <typename T, typename Context>
void FusedDropoutAddKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& y,
                           const Scalar& p,
                           bool is_test,
                           const std::string& mode,
                           int seed,
                           bool fix_seed,
                           DenseTensor* out,
                           DenseTensor* seed_offset) {
  auto* out_data = dev_ctx.template Alloc<T>(out);
  auto* seed_offset_data = dev_ctx.template HostAlloc<int64_t>(seed_offset);
  int64_t numel = x.numel();
  auto stream = dev_ctx.stream();
  bool upscale_in_train = (mode == "upscale_in_train");
  const auto* x_data = x.data<T>();
  const auto* y_data = y.data<T>();
  float dropout_rate = p.to<float>();

  if (!is_test) {
    if (dropout_rate == 1.0f) {
      phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, out);
      return;
    }
    uint64_t seed_data;
    uint64_t increment;
    auto random_prop = GetRandomCudaProp(numel, dev_ctx);
    size_t grid_size = random_prop[0];
    size_t block_size = random_prop[1];
    size_t offset = random_prop[2];
    size_t main_offset = random_prop[3];
    funcs::GetSeedDataAndIncrement(
        dev_ctx, nullptr, fix_seed, seed, offset, &seed_data, &increment);
    seed_offset_data[0] = static_cast<int64_t>(seed_data);
    seed_offset_data[1] = static_cast<int64_t>(increment);
    auto dst_functor =
        NoMaskFwFunctor<T, float>(1.0f - dropout_rate, upscale_in_train);

#define PD_DROPOUT_KERNEL_NAME \
  VectorizedDropoutForward<T, NoMaskFwFunctor<T, float>>
    PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed,
                                       PD_DROPOUT_KERNEL_NAME,
                                       grid_size,
                                       block_size,
                                       0,
                                       stream,
                                       offset,
                                       KERNEL_PARAMS.As<uint64_t>(1),
                                       KERNEL_PARAMS.As<uint64_t>(5),
                                       numel,
                                       seed_data,  // need save
                                       x_data,
                                       y_data,
                                       out_data,
                                       increment,  // need save
                                       main_offset,
                                       dst_functor);
#undef PD_DROPOUT_KERNEL_NAME
  } else {
    using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
    MT factor = static_cast<MT>(1.0f - dropout_rate);
    std::vector<phi::DenseTensor*> outs = {out};
    std::vector<const phi::DenseTensor*> ins = {&x, &y};
    phi::funcs::ElementwiseKernel<T>(
        dev_ctx, ins, &outs, ScaleAddFuctor<T>(factor, upscale_in_train));
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(fused_dropout_add,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedDropoutAddKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
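The forward functor applies the same per-element keep/scale decision before the residual add. A NumPy sketch of the upscale_in_train forward rule with an assumed mask; the kernel draws a fresh uniform sample per element and keeps those with rand < 1 - p:

# NumPy sketch of NoMaskFwFunctor in upscale_in_train mode; the mask is
# assumed for illustration, the kernel samples it on the fly.
import numpy as np

p = 0.5
x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([10.0, 10.0, 10.0, 10.0])
mask = np.array([1.0, 0.0, 1.0, 1.0])    # 1 = kept, 0 = dropped

out = x * mask / (1.0 - p) + y           # dropped lanes fall back to y alone
print(out)  # -> [12. 10. 16. 18.]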
python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py (new file, mode 100644, +189)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.fluid.core as core
from paddle import fluid
from paddle.incubate.nn.functional import fused_dropout_add
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd


def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
    tmp = paddle.nn.functional.dropout(x, p, training=training, mode=mode)
    return tmp + y


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA ",
)
class TestFusedDropoutAdd(unittest.TestCase):
    def setUp(self):
        self.shape = (2, 10, 10, 2)
        self.dtype = 'float64'
        self.dropout_rate = 0.9
        self.training = True
        self.mode = "upscale_in_train"
        self.seed = 1027

    def get_paddle_tensor(self):
        tmp = paddle.randn(self.shape, self.dtype)
        tmp.stop_gradient = False
        return tmp

    def get_forward_backward(self, dropout_add, seed):
        paddle.disable_static()
        paddle.seed(seed)
        count = 3
        data = []
        fw = []
        bw = []
        for _ in range(count):
            data.append(self.get_paddle_tensor())

        out = data[0]
        for i in range(1, count):
            out = dropout_add(
                out,
                data[i],
                p=self.dropout_rate,
                training=self.training,
                mode=self.mode,
            )
            fw.append(out)
        loss = paddle.mean(out)
        loss.backward()
        for i in range(count):
            bw.append(data[i].grad)
        return fw, bw

    def test_fused_dropout_add(self):
        p_fw, p_bw = self.get_forward_backward(paddle_dropout_add, seed=self.seed)
        f_fw, f_bw = self.get_forward_backward(fused_dropout_add, seed=self.seed)
        for i in range(len(p_fw)):
            np.testing.assert_allclose(p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05)
            np.testing.assert_allclose(p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05)


def create_test_class(parent, dtype, mode, training, p, seed):
    @unittest.skipIf(
        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
    )
    class TestFusedDropoutAddCase(parent):
        def setUp(self):
            self.shape = (2, 10, 10, 2)
            self.dtype = dtype
            self.dropout_rate = p
            self.training = training
            self.mode = mode
            self.seed = seed

    cls_name = "{0}_{1}_{2}_{3}_{4}_{5}".format(
        parent.__name__, dtype, mode, str(training), str(p), str(seed)
    )
    TestFusedDropoutAddCase.__name__ = cls_name
    globals()[cls_name] = TestFusedDropoutAddCase


for dtype in ["float64", "float32", "float16"]:
    for mode in ["upscale_in_train", "downscale_in_infer"]:
        for p in [0.0, 0.5, 0.9, 1.0]:
            for training in [True, False]:
                for seed in [0, 1024]:
                    create_test_class(
                        TestFusedDropoutAdd, dtype, mode, training, p, seed
                    )


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA "
)
class TestFusedDropoutAddStatic(unittest.TestCase):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (2, 80, 8, 2)
        self.dtype = 'float16'

    def test_static_op(self):
        paddle.disable_static()
        paddle.seed(312)
        x_data = np.random.random(self.shape)
        y_data = np.random.random(self.shape)
        x = paddle.to_tensor(
            x_data, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        y = paddle.to_tensor(
            y_data, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        out = fused_dropout_add(x, y, p=0.5, training=True)

        paddle.enable_static()
        paddle.seed(312)
        with paddle.static.program_guard(paddle.static.Program()):
            xs = paddle.static.data(name="xs", shape=self.shape, dtype=self.dtype)
            ys = paddle.static.data(name="ys", shape=self.shape, dtype=self.dtype)

            outs = fused_dropout_add(xs, ys, p=0.5, training=True)

            exe = fluid.Executor(self.place)
            out_s = exe.run(
                feed={
                    "xs": x_data.astype('float16'),
                    "ys": y_data.astype('float16'),
                },
                fetch_list=[outs],
            )
            np.testing.assert_allclose(out_s[0], out)

    def test_fused_dropout_add_layer(self):
        x = paddle.randn(self.shape, self.dtype)
        y = paddle.randn(self.shape, self.dtype)
        fused_d_a = FusedDropoutAdd(p=0.5)
        d = paddle.nn.Dropout(p=0.5)
        print(d)
        paddle.seed(2048)
        fused_out = fused_d_a(x, y)
        paddle.seed(2048)
        out = d(x) + y
        np.testing.assert_allclose(fused_out, out)

    def test_assert(self):
        def check_raise():
            x = paddle.randn(self.shape, self.dtype)
            y = paddle.randn(self.shape, self.dtype)
            fused_d_a = FusedDropoutAdd(p=-1)
            fused_out = fused_d_a(x, y)

        self.assertRaises(ValueError, check_raise)


if __name__ == '__main__':
    unittest.main()
python/paddle/incubate/nn/__init__.py

@@ -21,6 +21,7 @@ from .layer.fused_transformer import (
     FusedBiasDropoutResidualLayerNorm,
 )  # noqa: F401
 from .layer.fused_ec_moe import FusedEcMoe  # noqa: F401
+from .layer.fused_dropout_add import FusedDropoutAdd  # noqa: F401

 __all__ = [  # noqa
     'FusedMultiHeadAttention',
@@ -30,4 +31,5 @@ __all__ = [  # noqa
     'FusedLinear',
     'FusedBiasDropoutResidualLayerNorm',
     'FusedEcMoe',
+    'FusedDropoutAdd',
 ]
python/paddle/incubate/nn/functional/__init__.py

@@ -18,6 +18,7 @@ from .fused_transformer import fused_multi_transformer
 from .fused_matmul_bias import fused_matmul_bias, fused_linear
 from .fused_transformer import fused_bias_dropout_residual_layer_norm
 from .fused_ec_moe import fused_ec_moe
+from .fused_dropout_add import fused_dropout_add

 __all__ = [
     'fused_multi_head_attention',
@@ -27,4 +28,5 @@ __all__ = [
     'fused_linear',
     'fused_bias_dropout_residual_layer_norm',
     'fused_ec_moe',
+    'fused_dropout_add',
 ]
python/paddle/incubate/nn/functional/fused_dropout_add.py (new file, mode 100644, +116)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _C_ops
from paddle.common_ops_import import default_main_program
from paddle.fluid import core
from paddle.fluid.framework import in_dygraph_mode
from paddle.framework import LayerHelper


def fused_dropout_add(
    x, y, p=0.5, training=True, mode='upscale_in_train', name=None
):
    r"""
    Fused Dropout and Add.

    Args:
        x (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
        y (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
        p (float|int, optional): Probability of setting units to zero. Default: 0.5.
        training (bool, optional): A flag indicating whether it is in train phrase or not. Default: True.
        mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'].

            1. upscale_in_train (default), upscale the output at training time

                - train: :math:`out = x \times \frac{mask}{(1.0 - dropout\_prob)} + y`
                - inference: :math:`out = x + y`

            2. downscale_in_infer, downscale the output at inference

                - train: :math:`out = input \times mask + y`
                - inference: :math:`out = input \times (1.0 - dropout\_prob) + y`

        name (str, optional): Name for the operation, Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor representing the fused dropout and add, has same shape and data type as `x`.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn.functional import fused_dropout_add

            x = paddle.randn([4, 10], dtype='float16')
            y = paddle.randn([4, 10], dtype='float16')
            out = fused_dropout_add(x, y, p=0.5)
    """
    if isinstance(p, (int, float)):
        # fast return for p == 0
        if p == 0:
            return x + y
        elif p < 0 or p > 1:
            raise ValueError("p argument should between 0 and 1")
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
    seed = None
    if in_dygraph_mode():
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
        out, seed_offset = _C_ops.fused_dropout_add(
            x,
            y,
            p,
            not training,
            mode,
            seed if seed is not None else 0,
            seed is not None,
        )
        return out
    else:
        helper = LayerHelper('fused_dropout_add', **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        seed_offset = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.INT64, stop_gradient=True
        )

        def get_attrs(prog, dropout_prob, is_test, seed):
            if (seed is None or seed == 0) and prog.random_seed != 0:
                seed = prog.random_seed

            attrs = {
                'p': dropout_prob,
                'is_test': is_test,
                'mode': mode,
                'seed': seed if seed is not None else 0,
                'fix_seed': seed is not None,
            }
            return attrs

        attrs = get_attrs(helper.main_program, p, not training, seed)

        helper.append_op(
            type='fused_dropout_add',
            inputs={'x': x, 'y': y},
            outputs={'out': [out], 'seed_offset': [seed_offset]},
            attrs=attrs,
        )
        return out
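One consequence of the two modes worth calling out: with training=False, upscale_in_train reduces to a plain x + y, while downscale_in_infer scales x by (1 - p) before the add. A short sketch of the eval-time behaviour (assumes a CUDA build of Paddle; the constant inputs are chosen so the expected values are easy to check against the docstring formulas):

# Eval-time behaviour of the two modes (deterministic, no mask is sampled).
# Assumes a CUDA build of Paddle; shapes and values are illustrative.
import paddle
from paddle.incubate.nn.functional import fused_dropout_add

x = paddle.full([2, 3], 4.0, dtype='float32')
y = paddle.full([2, 3], 1.0, dtype='float32')

out_up = fused_dropout_add(x, y, p=0.5, training=False,
                           mode='upscale_in_train')      # x + y         -> 5.0
out_down = fused_dropout_add(x, y, p=0.5, training=False,
                             mode='downscale_in_infer')  # x*(1 - p) + y -> 3.0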
python/paddle/incubate/nn/layer/fused_dropout_add.py (new file, mode 100644, +78)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.incubate.nn import functional as F
from paddle.nn import Layer


class FusedDropoutAdd(Layer):
    r"""
    Fused Dropout and Add.

    Parameters:
        p (float|int, optional): Probability of setting units to zero. Default: 0.5
        mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

            1. upscale_in_train (default), upscale the output at training time

                - train: :math:`out = x \times \frac{mask}{(1.0 - p)} + y`
                - inference: :math:`out = x + y`

            2. downscale_in_infer, downscale the output at inference

                - train: :math:`out = x \times mask + y`
                - inference: :math:`out = x \times (1.0 - p) + y`

        name (str, optional): Name for the operation, Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x: N-D tensor.
        - y: N-D tensor.
        - output: N-D tensor, the same shape as x.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd

            x = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32")
            y = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32")

            m = FusedDropoutAdd(p=0.5)

            out = m(x, y)
    """

    def __init__(self, p=0.5, mode="upscale_in_train", name=None):
        super().__init__()
        self.p = p
        self.mode = mode
        self.name = name

    def forward(self, x, y):
        out = F.fused_dropout_add(
            x,
            y,
            p=self.p,
            training=self.training,
            mode=self.mode,
            name=self.name,
        )
        return out

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'p={}, mode={}{}'.format(self.p, self.mode, name_str)