Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cf8bf032
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
cf8bf032
编写于
9月 09, 2021
作者:
Z
zhangkaihuo
提交者:
GitHub
9月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a fusion op: fused_residual_dropout_bias (#34963)
上级
eb1fbf12
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
871 addition
and
0 deletion
+871
-0
paddle/fluid/operators/fused/CMakeLists.txt
paddle/fluid/operators/fused/CMakeLists.txt
+5
-0
paddle/fluid/operators/fused/fused_dropout_common.h
paddle/fluid/operators/fused/fused_dropout_common.h
+99
-0
paddle/fluid/operators/fused/fused_dropout_test.h
paddle/fluid/operators/fused/fused_dropout_test.h
+117
-0
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
+322
-0
paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu
...fluid/operators/fused/fused_residual_dropout_bias_test.cu
+328
-0
未找到文件。
paddle/fluid/operators/fused/CMakeLists.txt
浏览文件 @
cf8bf032
...
...
@@ -71,4 +71,9 @@ if (WITH_GPU OR WITH_ROCM)
op_library
(
fused_bn_add_activation_op
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(fused_bn_add_activation);
\n
"
)
endif
()
# fused_dropout
# only support CUDA
if
(
NOT WITH_ROCM
)
nv_test
(
test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory
)
endif
()
endif
()
paddle/fluid/operators/fused/fused_dropout_common.h
0 → 100644
浏览文件 @
cf8bf032
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cooperative_groups.h>
#include <cuda.h>
#include <curand_kernel.h>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace
paddle
{
namespace
operators
{
#define CACHE_LINE 128
#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)
/**
* get the threads for fused_residual_dropout_bias:
* 1D blocks: blockDim.x = cols
* 2D grids: gridDim.y = rows
*/
inline
platform
::
GpuLaunchConfig
Get1DBlocksAnd2DGrids
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
uint32_t
rows
,
const
uint32_t
cols
,
const
int
VecSize
)
{
const
uint32_t
tmp_cols
=
cols
/
VecSize
;
int
threads
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
std
::
min
(
tmp_cols
,
static_cast
<
uint32_t
>
(
ctx
.
GetMaxThreadsPerBlock
())));
const
auto
blocks_x
=
std
::
max
(
static_cast
<
uint32_t
>
(
1
),
(
tmp_cols
+
threads
-
1
)
/
threads
);
const
auto
blocks_y
=
std
::
max
(
static_cast
<
uint32_t
>
(
1
),
rows
);
platform
::
GpuLaunchConfig
config
;
config
.
block_per_grid
.
x
=
blocks_x
;
config
.
block_per_grid
.
y
=
blocks_y
;
config
.
thread_per_block
.
x
=
threads
;
return
config
;
}
__forceinline__
__device__
void
Rand1
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
data
[
0
]
=
curand_uniform
(
state
);
}
__forceinline__
__device__
void
Rand2
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
data
[
0
]
=
curand_uniform
(
state
);
data
[
1
]
=
curand_uniform
(
state
);
}
__forceinline__
__device__
void
Rand4
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
float4
rand4
=
curand_uniform4
(
state
);
data
[
0
]
=
rand4
.
x
;
data
[
1
]
=
rand4
.
y
;
data
[
2
]
=
rand4
.
w
;
data
[
3
]
=
rand4
.
z
;
}
__forceinline__
__device__
void
Rand8
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
Rand4
(
state
,
data
);
Rand4
(
state
,
data
+
4
);
}
__forceinline__
__device__
void
RandVec
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
,
const
int
VecSize
)
{
if
(
VecSize
==
1
)
{
Rand1
(
state
,
data
);
}
else
if
(
VecSize
==
2
)
{
Rand2
(
state
,
data
);
}
else
if
(
VecSize
==
4
)
{
Rand4
(
state
,
data
);
}
else
if
(
VecSize
==
8
)
{
Rand8
(
state
,
data
);
}
else
{
return
;
}
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/fused_dropout_test.h
0 → 100644
浏览文件 @
cf8bf032
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
memory
=
paddle
::
memory
;
USE_OP
(
dropout
);
/**
* @brief call paddle dropout op
*/
template
<
typename
T
>
void
Dropout
(
const
std
::
vector
<
T
>
&
x
,
const
framework
::
DDim
&
x_dim
,
std
::
vector
<
T
>
*
out
,
std
::
vector
<
uint8_t
>
*
mask
,
const
platform
::
CUDADeviceContext
&
ctx
,
uint64_t
seed
,
float
dropout_prob
,
bool
is_upscale_in_train
,
bool
is_test
)
{
framework
::
Scope
scope
;
auto
var_x
=
scope
.
Var
(
"X"
);
auto
tensor_x
=
var_x
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorFromVector
(
x
,
ctx
,
tensor_x
);
tensor_x
->
Resize
(
x_dim
);
auto
var_out
=
scope
.
Var
(
"Out"
);
auto
tensor_out
=
var_out
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
var_mask
=
scope
.
Var
(
"Mask"
);
auto
tensor_mask
=
var_mask
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"fix_seed"
,
1
});
attrs
.
insert
({
"seed"
,
static_cast
<
int
>
(
seed
)});
attrs
.
insert
({
"dropout_prob"
,
dropout_prob
});
if
(
is_upscale_in_train
)
{
attrs
.
insert
({
"dropout_implementation"
,
std
::
string
(
"upscale_in_train"
)});
}
if
(
is_test
)
{
attrs
.
insert
({
"is_test"
,
true
});
}
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"dropout"
,
{{
"X"
,
{
"X"
}}},
{{
"Out"
,
{
"Out"
}},
{
"Mask"
,
{
"Mask"
}}},
attrs
);
op
->
Run
(
scope
,
ctx
.
GetPlace
());
framework
::
TensorToVector
<
T
>
(
*
tensor_out
,
ctx
,
out
);
if
(
!
is_test
)
{
framework
::
TensorToVector
<
uint8_t
>
(
*
tensor_mask
,
ctx
,
mask
);
}
ctx
.
Wait
();
}
/**
* @brief call paddle dropout_grad op
*/
template
<
typename
T
>
void
DropoutGrad
(
std
::
vector
<
T
>
*
dx
,
const
framework
::
DDim
&
x_dim
,
const
std
::
vector
<
T
>
&
dout
,
const
std
::
vector
<
uint8_t
>
&
mask
,
const
platform
::
CUDADeviceContext
&
ctx
,
float
dropout_prob
,
bool
is_upscale_in_train
)
{
framework
::
Scope
scope
;
const
size_t
n
=
x_dim
[
0
]
*
x_dim
[
1
];
auto
var_out
=
scope
.
Var
(
"DOut"
);
auto
tensor_out
=
var_out
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorFromVector
(
dout
,
ctx
,
tensor_out
);
tensor_out
->
Resize
(
x_dim
);
auto
var_mask
=
scope
.
Var
(
"Mask"
);
auto
tensor_mask
=
var_mask
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorFromVector
(
mask
,
ctx
,
tensor_mask
);
tensor_mask
->
Resize
(
x_dim
);
auto
var_dx
=
scope
.
Var
(
"DX"
);
auto
tensor_dx
=
var_dx
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"dropout_prob"
,
dropout_prob
});
attrs
.
insert
({
"is_test"
,
false
});
if
(
is_upscale_in_train
)
{
attrs
.
insert
({
"dropout_implementation"
,
std
::
string
(
"upscale_in_train"
)});
}
else
{
attrs
.
insert
({
"dropout_implementation"
,
std
::
string
(
"downgrade_in_infer"
)});
}
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
"dropout_grad"
,
{{
"Out@GRAD"
,
{
"DOut"
}},
{
"Mask"
,
{
"Mask"
}}},
{{
"X@GRAD"
,
{
"DX"
}}},
attrs
);
op
->
Run
(
scope
,
ctx
.
GetPlace
());
framework
::
TensorToVector
(
*
tensor_dx
,
ctx
,
dx
);
ctx
.
Wait
();
}
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
0 → 100644
浏览文件 @
cf8bf032
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
namespace
paddle
{
namespace
operators
{
/**
* @brief The fused function called by every thread
* VecSize can be 1, 2, 4 or 8
*/
template
<
typename
T
,
typename
MaskType
,
int
VecSize
,
bool
ComputeLayerNorm
>
__forceinline__
__device__
void
FusedResidualDropoutBiasOneThread
(
const
int
row_id
,
const
int
col_id
,
const
int
cols
,
curandStatePhilox4_32_10_t
*
state
,
const
float
dropout_prob
,
const
T
factor
,
const
T
*
src
,
const
T
*
residual
,
const
T
*
bias
,
T
*
dst
,
MaskType
*
mask
,
const
bool
is_test
,
typename
details
::
MPTypeTrait
<
T
>::
Type
*
mean_val
,
typename
details
::
MPTypeTrait
<
T
>::
Type
*
var_val
)
{
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskStoreT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
using
U
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
LoadT
src_vec
;
LoadT
residual_vec
;
LoadT
bias_vec
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
bias_vec
[
ii
]
=
static_cast
<
T
>
(
0
);
}
// vectorize load data from global
platform
::
Load
<
T
,
VecSize
>
(
&
src
[
row_id
*
cols
+
col_id
],
&
src_vec
);
platform
::
Load
<
T
,
VecSize
>
(
&
residual
[
row_id
*
cols
+
col_id
],
&
residual_vec
);
if
(
bias
)
{
platform
::
Load
<
T
,
VecSize
>
(
&
bias
[
col_id
],
&
bias_vec
);
}
MaskStoreT
mask_vec
;
if
(
!
is_test
)
{
float
rand
[
VecSize
];
RandVec
(
state
,
rand
,
VecSize
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
mask_vec
[
ii
]
=
static_cast
<
MaskType
>
(
rand
[
ii
]
>=
dropout_prob
);
}
}
else
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
mask_vec
[
ii
]
=
static_cast
<
MaskType
>
(
1
);
}
}
StoreT
dest_vec
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
dest_vec
[
ii
]
=
(
src_vec
[
ii
]
+
bias_vec
[
ii
])
*
static_cast
<
T
>
(
mask_vec
[
ii
])
*
factor
+
residual_vec
[
ii
];
if
(
ComputeLayerNorm
)
{
U
tmp
=
static_cast
<
U
>
(
dest_vec
[
ii
]);
*
mean_val
+=
tmp
;
*
var_val
+=
(
tmp
*
tmp
);
}
}
// store result to global
platform
::
Store
<
T
,
VecSize
>
(
dest_vec
,
&
dst
[
row_id
*
cols
+
col_id
]);
if
(
!
is_test
)
{
platform
::
Store
<
MaskType
,
VecSize
>
(
mask_vec
,
&
mask
[
row_id
*
cols
+
col_id
]);
}
}
/**
* @brief dst = residual + dropout(src + bias);
* the src, residual, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
* is_test: only used in inference
* mask: can be null if is_test=true
*/
template
<
typename
T
,
typename
MaskType
,
int
VecSize
>
__global__
void
FusedResidualDropoutBias
(
const
size_t
rows
,
const
size_t
cols
,
uint64_t
seed
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
T
*
src
,
const
T
*
residual
,
const
T
*
bias
,
MaskType
*
mask
,
T
*
dst
,
uint64_t
increment
,
const
bool
is_test
)
{
int
col_id
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
row_id
=
blockIdx
.
y
;
int
idx
=
row_id
*
cols
+
col_id
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
increment
,
&
state
);
T
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
if
(
!
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
}
if
(
is_test
)
{
factor
=
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
if
(
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
}
}
for
(
int
r
=
row_id
;
r
<
rows
;
r
+=
blockDim
.
y
*
gridDim
.
y
)
{
for
(
int
i
=
col_id
*
VecSize
;
i
<
cols
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
FusedResidualDropoutBiasOneThread
<
T
,
MaskType
,
VecSize
,
false
>
(
r
,
i
,
cols
,
&
state
,
dropout_prob
,
factor
,
src
,
residual
,
bias
,
dst
,
mask
,
is_test
,
nullptr
,
nullptr
);
}
}
}
/**
* @brief dst = residual + dropout(src + bias);
*/
template
<
typename
T
,
typename
MaskType
>
void
LaunchResidualDropoutBias
(
const
uint32_t
rows
,
const
uint32_t
cols
,
const
int
increment
,
uint64_t
seed
,
const
float
dropout_prob
,
const
bool
is_test
,
bool
is_upscale_in_train
,
const
T
*
src
,
const
T
*
residual
,
const
T
*
bias
,
MaskType
*
mask_data
,
T
*
dst
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
// dropout_prob == 1.0f
if
(
std
::
abs
(
dropout_prob
-
1.0
f
)
<
1e-5
)
{
if
(
residual
==
dst
)
return
;
auto
cuda_place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
memory
::
Copy
(
cuda_place
,
dst
,
cuda_place
,
residual
,
rows
*
cols
*
sizeof
(
T
),
ctx
.
stream
());
if
(
!
is_test
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemsetAsync
(
mask_data
,
0
,
rows
*
cols
*
sizeof
(
MaskType
),
ctx
.
stream
()));
}
return
;
}
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
const
int
real_vec_size
=
cols
%
VecSize
==
0
?
VecSize
:
1
;
auto
config
=
Get1DBlocksAnd2DGrids
(
ctx
,
rows
,
cols
,
real_vec_size
);
if
(
cols
%
VecSize
==
0
)
{
FusedResidualDropoutBias
<
T
,
uint8_t
,
VecSize
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
src
,
residual
,
bias
,
mask_data
,
dst
,
increment
,
is_test
);
}
else
{
FusedResidualDropoutBias
<
T
,
uint8_t
,
1
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
src
,
residual
,
bias
,
mask_data
,
dst
,
increment
,
is_test
);
}
}
/*
* @brief calculate the grad of no bias
*/
template
<
typename
T
,
typename
MaskType
,
int
VecSize
>
__global__
void
FusedResidualDropoutGrad
(
const
T
*
dout
,
const
MaskType
*
mask
,
const
T
factor
,
const
int64_t
size
,
T
*
dx
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
for
(
int
i
=
idx
*
VecSize
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
LoadT
dout_vec
;
MaskLoadT
mask_vec
;
platform
::
Load
<
T
,
VecSize
>
(
&
dout
[
i
],
&
dout_vec
);
platform
::
Load
<
MaskType
,
VecSize
>
(
&
mask
[
i
],
&
mask_vec
);
StoreT
dx_vec
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
dx_vec
[
ii
]
=
dout_vec
[
ii
]
*
static_cast
<
T
>
(
mask_vec
[
ii
])
*
factor
;
}
platform
::
Store
<
T
,
VecSize
>
(
dx_vec
,
&
dx
[
i
]);
}
}
/**
* blocks(128 * 8)
* 1. calculate the dx and reduce total rows to 128 rows
* 2. save 128*8 temporary sum in 8*128 shared memory
* 3. reduce the sum of 128 rows data by 8*VecSize warps
*/
template
<
typename
T
,
typename
MaskType
,
int
BlockSizeX
,
int
BlockSizeY
,
int
VecSize
>
__global__
void
FusedResidualDropoutBiasGrad
(
const
T
*
dout
,
const
MaskType
*
mask
,
const
T
factor
,
const
int64_t
rows
,
const
int64_t
cols
,
T
*
dx
,
T
*
dbias
)
{
int64_t
col_id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
T
tmp_sum
[
VecSize
]
=
{
static_cast
<
T
>
(
0
)};
// calculate the dx and temporary sum
if
(
col_id
*
VecSize
<
cols
)
{
for
(
int
row_id
=
threadIdx
.
y
;
row_id
<
rows
;
row_id
+=
blockDim
.
y
)
{
int
index
=
row_id
*
cols
+
col_id
*
VecSize
;
LoadT
out_vec
;
MaskLoadT
mask_vec
;
StoreT
dx_vec
;
platform
::
Load
<
T
,
VecSize
>
(
&
dout
[
index
],
&
out_vec
);
platform
::
Load
<
MaskType
,
VecSize
>
(
&
mask
[
index
],
&
mask_vec
);
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
dx_vec
[
i
]
=
out_vec
[
i
]
*
static_cast
<
T
>
(
mask_vec
[
i
])
*
factor
;
tmp_sum
[
i
]
+=
out_vec
[
i
];
}
platform
::
Store
<
T
,
VecSize
>
(
dx_vec
,
&
dx
[
index
]);
}
}
// save temporary sum to cache and do transpose
__shared__
T
cache
[
BlockSizeX
*
VecSize
][
BlockSizeY
];
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
cache
[
threadIdx
.
x
*
VecSize
+
i
][
threadIdx
.
y
]
=
tmp_sum
[
i
];
}
__syncthreads
();
// reduce sum
T
sum
=
static_cast
<
T
>
(
0
);
int
tid
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
x
=
tid
>>
5
;
// warp id
int
y
=
tid
&
31
;
// thread id on warp 0~31
// need BlockSizeX * VecSize warps
if
(
x
<
BlockSizeX
*
VecSize
)
{
// reduce 128 to 32
#pragma unroll
for
(
int
i
=
0
;
i
<
(
BlockSizeY
>>
5
);
i
++
)
{
sum
+=
cache
[
x
][
y
+
i
*
32
];
}
}
// reduce 32 to 1
sum
=
WarpReduceSum
(
sum
);
// save sum to dbias
int
bias_id
=
blockIdx
.
x
*
blockDim
.
x
*
VecSize
+
x
;
if
(
y
==
0
&&
x
<
VecSize
*
BlockSizeX
&&
bias_id
<
cols
)
{
dbias
[
bias_id
]
=
sum
;
}
}
/**
* @brief to launch kernel FusedResidualDropoutBiasGradVec
*/
template
<
typename
T
,
typename
MaskType
>
void
LaunchResidualDropoutBiasGrad
(
const
T
*
dout
,
const
MaskType
*
mask
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
uint32_t
rows
,
const
uint32_t
cols
,
T
*
dx
,
T
*
dbias
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
const
T
zero
=
static_cast
<
T
>
(
0.0
f
);
auto
factor
=
dropout_prob
==
static_cast
<
float
>
(
1.0
f
)
?
zero
:
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
if
(
!
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
}
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
int
real_vec_size
=
cols
%
VecSize
==
0
?
VecSize
:
1
;
if
(
dbias
!=
nullptr
)
{
auto
threads
=
std
::
min
(
cols
/
real_vec_size
,
static_cast
<
uint32_t
>
(
8
));
auto
blocks
=
std
::
max
((
uint32_t
)
1
,
(
cols
/
real_vec_size
+
threads
-
1
)
/
threads
);
dim3
block_dim
(
threads
,
128
,
1
);
dim3
grid_dim
(
blocks
,
1
,
1
);
if
(
cols
%
VecSize
==
0
)
{
FusedResidualDropoutBiasGrad
<
T
,
MaskType
,
8
,
128
,
VecSize
><<<
grid_dim
,
block_dim
,
0
,
ctx
.
stream
()
>>>
(
dout
,
mask
,
factor
,
rows
,
cols
,
dx
,
dbias
);
}
else
{
FusedResidualDropoutBiasGrad
<
T
,
MaskType
,
8
,
128
,
1
><<<
grid_dim
,
block_dim
,
0
,
ctx
.
stream
()
>>>
(
dout
,
mask
,
factor
,
rows
,
cols
,
dx
,
dbias
);
}
}
else
{
const
uint64_t
n
=
rows
*
cols
;
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
,
n
/
real_vec_size
);
if
(
n
%
VecSize
==
0
)
{
FusedResidualDropoutGrad
<
T
,
MaskType
,
VecSize
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
dout
,
mask
,
factor
,
n
,
dx
);
}
else
{
FusedResidualDropoutGrad
<
T
,
MaskType
,
1
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
dout
,
mask
,
factor
,
n
,
dx
);
}
}
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu
0 → 100644
浏览文件 @
cf8bf032
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include <random>
#include <vector>
#include "paddle/fluid/operators/fused/fused_dropout_test.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
/**
* @brief the unittest of fusedresidualdropoutbias
* 1. random input data
* 2. add bias, call paddle dropout op, add residual, and get the base result
* 3. call FusedResidualDropoutBias function get fused result
* 4. compare ther base result and fused result
*/
template
<
typename
T
>
struct
TestFusedResidualDropoutBias
{
uint32_t
rows
;
uint32_t
cols
;
uint64_t
seed
;
float
dropout_prob
;
bool
is_upscale_in_train
;
bool
is_test
;
// default false, Set to true for inference only
bool
has_bias
=
true
;
framework
::
Tensor
src
,
residual
,
bias
,
out
,
mask
;
framework
::
Tensor
dsrc
,
dbias
;
std
::
vector
<
T
>
src_vec
,
residual_vec
,
bias_vec
;
std
::
vector
<
T
>
correct_out
,
correct_dsrc
,
correct_dbias
;
std
::
vector
<
uint8_t
>
correct_mask
;
platform
::
CUDAPlace
place
;
platform
::
CUDADeviceContext
*
ctx
;
TestFusedResidualDropoutBias
()
{
rows
=
32
;
cols
=
32
;
seed
=
0
;
dropout_prob
=
0.0
;
is_upscale_in_train
=
false
;
is_test
=
false
;
has_bias
=
true
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
device_ctx
=
pool
.
Get
(
place
);
ctx
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
device_ctx
);
}
TestFusedResidualDropoutBias
(
int
rows_
,
int
cols_
,
uint64_t
seed_
=
0
,
float
dropout_prob_
=
0.0
,
bool
is_upscale_in_train_
=
false
,
bool
is_test_
=
false
)
{
rows
=
rows_
;
cols
=
cols_
;
seed
=
seed_
;
dropout_prob
=
dropout_prob_
;
is_upscale_in_train
=
is_upscale_in_train_
;
is_test
=
is_test_
;
has_bias
=
true
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
device_ctx
=
pool
.
Get
(
place
);
ctx
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
device_ctx
);
}
~
TestFusedResidualDropoutBias
()
{}
void
SetUp
()
{
const
int
n
=
rows
*
cols
;
correct_out
.
resize
(
n
);
correct_mask
.
resize
(
n
);
correct_dsrc
.
resize
(
n
);
correct_dbias
.
resize
(
cols
);
src_vec
.
resize
(
n
);
residual_vec
.
resize
(
n
);
bias_vec
.
resize
(
cols
);
std
::
default_random_engine
random
(
time
(
NULL
));
std
::
uniform_real_distribution
<
float
>
dis
(
0.0
,
1.0
);
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
src_vec
[
i
*
cols
+
j
]
=
static_cast
<
T
>
(
dis
(
random
));
residual_vec
[
i
*
cols
+
j
]
=
static_cast
<
T
>
(
dis
(
random
));
if
(
i
==
0
)
{
bias_vec
[
j
]
=
dis
(
random
);
}
}
}
framework
::
TensorFromVector
<
T
>
(
src_vec
,
*
ctx
,
&
src
);
src
.
Resize
({
rows
,
cols
});
framework
::
TensorFromVector
<
T
>
(
residual_vec
,
*
ctx
,
&
residual
);
residual
.
Resize
({
rows
,
cols
});
if
(
has_bias
)
{
framework
::
TensorFromVector
<
T
>
(
bias_vec
,
*
ctx
,
&
bias
);
bias
.
Resize
({
cols
});
}
{
out
.
Resize
({
rows
,
cols
});
out
.
mutable_data
<
T
>
(
place
);
mask
.
Resize
({
rows
,
cols
});
mask
.
mutable_data
<
uint8_t
>
(
place
);
dsrc
.
Resize
({
rows
,
cols
});
dsrc
.
mutable_data
<
T
>
(
place
);
if
(
has_bias
)
{
dbias
.
Resize
({
cols
});
dbias
.
mutable_data
<
T
>
(
place
);
}
}
}
void
BaseForward
()
{
std
::
vector
<
T
>
out1
(
rows
*
cols
),
out2
(
rows
*
cols
);
if
(
has_bias
)
{
// add bias
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
out1
[
i
*
cols
+
j
]
=
src_vec
[
i
*
cols
+
j
]
+
bias_vec
[
j
];
}
}
// call dropout
Dropout
<
T
>
(
out1
,
src
.
dims
(),
&
out2
,
&
correct_mask
,
*
ctx
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
);
}
else
{
Dropout
<
T
>
(
src_vec
,
src
.
dims
(),
&
out2
,
&
correct_mask
,
*
ctx
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
);
}
ctx
->
Wait
();
// add residual
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
correct_out
[
i
*
cols
+
j
]
=
residual_vec
[
i
*
cols
+
j
]
+
out2
[
i
*
cols
+
j
];
}
}
}
void
BaseBackward
()
{
DropoutGrad
<
T
>
(
&
correct_dsrc
,
src
.
dims
(),
correct_out
,
correct_mask
,
*
ctx
,
dropout_prob
,
is_upscale_in_train
);
// calc dbias
memset
(
&
correct_dbias
[
0
],
0
,
cols
*
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
correct_dbias
[
j
]
+=
correct_out
[
i
*
cols
+
j
];
}
}
}
void
FusedForward
()
{
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
auto
config
=
paddle
::
operators
::
Get1DBlocksAnd2DGrids
(
*
ctx
,
(
uint64_t
)
rows
,
(
uint64_t
)
cols
,
VecSize
);
const
int
increment
=
((
cols
-
1
)
/
(
config
.
thread_per_block
.
x
*
config
.
block_per_grid
.
x
*
VecSize
)
+
1
)
*
VecSize
;
T
*
bias_ptr
=
nullptr
;
if
(
has_bias
)
{
bias_ptr
=
bias
.
data
<
T
>
();
}
paddle
::
operators
::
LaunchResidualDropoutBias
<
T
,
uint8_t
>
(
rows
,
cols
,
increment
,
seed
,
dropout_prob
,
is_test
,
is_upscale_in_train
,
src
.
data
<
T
>
(),
residual
.
data
<
T
>
(),
bias_ptr
,
mask
.
data
<
uint8_t
>
(),
out
.
data
<
T
>
(),
*
ctx
);
ctx
->
Wait
();
}
void
FusedBackward
()
{
if
(
is_test
)
{
return
;
}
T
*
bias_ptr
=
nullptr
;
if
(
has_bias
)
{
bias_ptr
=
dbias
.
data
<
T
>
();
}
paddle
::
operators
::
LaunchResidualDropoutBiasGrad
<
T
,
uint8_t
>
(
out
.
data
<
T
>
(),
mask
.
data
<
uint8_t
>
(),
dropout_prob
,
is_upscale_in_train
,
rows
,
cols
,
dsrc
.
data
<
T
>
(),
bias_ptr
,
*
ctx
);
}
void
Run
()
{
SetUp
();
BaseForward
();
FusedForward
();
BaseBackward
();
FusedBackward
();
}
void
CheckOut
(
const
T
diff
)
{
const
int
n
=
rows
*
cols
;
std
::
vector
<
T
>
_out
(
n
);
std
::
vector
<
uint8_t
>
_mask
(
n
);
framework
::
TensorToVector
(
out
,
*
ctx
,
&
_out
);
if
(
!
is_test
)
{
framework
::
TensorToVector
<
uint8_t
>
(
mask
,
*
ctx
,
&
_mask
);
}
ctx
->
Wait
();
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
EXPECT_LT
(
std
::
abs
(
_out
[
i
]
-
correct_out
[
i
]),
diff
);
if
(
!
is_test
)
EXPECT_EQ
(
_mask
[
i
],
correct_mask
[
i
]);
}
}
void
CheckGrad
(
const
T
diff
)
{
if
(
is_test
)
{
return
;
}
const
int
n
=
rows
*
cols
;
std
::
vector
<
T
>
_dsrc
(
n
);
framework
::
TensorToVector
(
dsrc
,
*
ctx
,
&
_dsrc
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
EXPECT_LT
(
std
::
abs
(
_dsrc
[
i
]
-
correct_dsrc
[
i
]),
diff
);
}
if
(
has_bias
)
{
std
::
vector
<
T
>
_dbias
(
cols
);
framework
::
TensorToVector
(
dbias
,
*
ctx
,
&
_dbias
);
ctx
->
Wait
();
for
(
int
i
=
0
;
i
<
cols
;
i
++
)
{
EXPECT_LT
(
std
::
abs
(
_dbias
[
i
]
-
correct_dbias
[
i
]),
diff
);
}
}
}
};
// test the shape and bias
template
<
typename
T
>
static
void
BaseTest
(
const
bool
is_fp16
=
false
)
{
const
int
rows
=
16
;
std
::
vector
<
int
>
cols_list
=
{
16
,
17
};
bool
has_bias
[
2
]
=
{
true
,
false
};
T
default_diff
=
static_cast
<
T
>
(
1e-5
);
if
(
is_fp16
)
{
default_diff
=
static_cast
<
T
>
(
1e-2
);
}
for
(
int
i
=
0
;
i
<
cols_list
.
size
();
i
++
)
{
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
TestFusedResidualDropoutBias
<
T
>
test
(
rows
,
cols_list
[
i
]);
test
.
has_bias
=
has_bias
[
j
];
test
.
Run
();
test
.
CheckOut
(
default_diff
);
if
(
!
is_fp16
)
{
test
.
CheckGrad
(
default_diff
);
}
}
}
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias
)
{
BaseTest
<
float
>
();
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBiasDouble
)
{
BaseTest
<
double
>
();
}
// test fp16, For inference, check_grad is not required. ref: testdropout_op.py
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBiasFp16
)
{
BaseTest
<
platform
::
float16
>
(
true
);
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias2
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
0
,
1.0
,
false
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias3
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
0
,
1.0
,
true
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias4
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
0
,
0.35
,
true
,
true
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias5
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
125
,
0.0
,
false
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
// test large shape
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias6
)
{
const
int
rows
=
256
;
const
int
cols
=
4096
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-3
));
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录