Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cee70434
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cee70434
编写于
9月 16, 2021
作者:
Z
zhangkaihuo
提交者:
GitHub
9月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a fusion op: fused_dropout_act_bias (#35129)
上级
bab39eb2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
787 addition
and
119 deletion
+787
-119
paddle/fluid/operators/fused/CMakeLists.txt
paddle/fluid/operators/fused/CMakeLists.txt
+1
-0
paddle/fluid/operators/fused/fused_dropout_act_bias.h
paddle/fluid/operators/fused/fused_dropout_act_bias.h
+317
-0
paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu
paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu
+346
-0
paddle/fluid/operators/fused/fused_dropout_common.h
paddle/fluid/operators/fused/fused_dropout_common.h
+62
-24
paddle/fluid/operators/fused/fused_dropout_test.h
paddle/fluid/operators/fused/fused_dropout_test.h
+19
-0
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
+17
-50
paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu
...fluid/operators/fused/fused_residual_dropout_bias_test.cu
+25
-45
未找到文件。
paddle/fluid/operators/fused/CMakeLists.txt
浏览文件 @
cee70434
...
...
@@ -75,5 +75,6 @@ if (WITH_GPU OR WITH_ROCM)
# only support CUDA
if
(
NOT WITH_ROCM
)
nv_test
(
test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory
)
nv_test
(
test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory
)
endif
()
endif
()
paddle/fluid/operators/fused/fused_dropout_act_bias.h
0 → 100755
浏览文件 @
cee70434
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/math/functors.h"
namespace
paddle
{
namespace
operators
{
/**
*@brief the gelu functor
*/
template
<
typename
T
>
struct
GeluFunctor
{
inline
__host__
__device__
T
operator
()(
const
T
x
)
const
{
using
U
=
LayerNormParamType
<
T
>
;
const
U
casted_x
=
static_cast
<
U
>
(
x
);
const
U
temp
=
erf
(
casted_x
*
static_cast
<
U
>
(
M_SQRT1_2
));
const
U
out
=
(
casted_x
*
static_cast
<
U
>
(
0.5
)
*
(
static_cast
<
U
>
(
1
)
+
temp
));
return
static_cast
<
T
>
(
out
);
}
};
/**
*@brief the gelu grad functor
*/
template
<
typename
T
>
struct
GeluGradFunctor
{
inline
__host__
__device__
T
UseOut
(
const
T
x
)
const
{
using
U
=
LayerNormParamType
<
T
>
;
auto
casted_x
=
static_cast
<
U
>
(
x
);
auto
first
=
static_cast
<
U
>
(
0.5
)
*
(
static_cast
<
U
>
(
1
)
+
erf
(
casted_x
*
static_cast
<
U
>
(
M_SQRT1_2
)));
auto
second
=
static_cast
<
U
>
(
0.5
*
M_2_SQRTPI
*
M_SQRT1_2
)
*
casted_x
*
exp
(
-
static_cast
<
U
>
(
0.5
)
*
casted_x
*
casted_x
);
return
static_cast
<
T
>
((
first
+
second
));
}
};
/**
* @brief dst = dropout(activation(src + bias));
* the src, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
*/
template
<
typename
T
,
typename
MaskType
,
int
VecSize
,
typename
Functor
>
__global__
void
FusedDropoutActBias
(
Functor
act
,
const
uint64_t
seed
,
const
uint64_t
rows
,
const
uint64_t
cols
,
const
int
increment
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
T
*
__restrict__
src
,
const
T
*
__restrict__
bias
,
T
*
dst
,
MaskType
*
mask
)
{
int
col_id
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
row_id
=
blockIdx
.
y
;
int
idx
=
row_id
*
cols
+
col_id
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
increment
,
&
state
);
T
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
if
(
!
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
);
}
if
(
is_test
)
{
factor
=
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
if
(
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
}
}
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
using
MaskStoreT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
for
(
int
r
=
row_id
;
r
<
rows
;
r
+=
blockDim
.
y
*
gridDim
.
y
)
{
for
(
int
i
=
col_id
*
VecSize
;
i
<
cols
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
LoadT
src_vec
;
LoadT
bias_vec
;
// vectorize load data from global
platform
::
Load
<
T
,
VecSize
>
(
&
src
[
r
*
cols
+
i
],
&
src_vec
);
if
(
bias
)
{
platform
::
Load
<
T
,
VecSize
>
(
&
bias
[
i
],
&
bias_vec
);
}
else
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
bias_vec
[
ii
]
=
static_cast
<
T
>
(
0
);
}
}
MaskStoreT
mask_vec
;
if
(
!
is_test
)
{
float
rand
[
VecSize
];
RandVec
<
VecSize
>
(
&
state
,
rand
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
mask_vec
[
ii
]
=
static_cast
<
MaskType
>
(
rand
[
ii
]
>=
dropout_prob
);
}
}
else
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
mask_vec
[
ii
]
=
static_cast
<
MaskType
>
(
1
);
}
}
StoreT
dest_vec
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
const
T
tmp
=
src_vec
[
ii
]
+
bias_vec
[
ii
];
const
T
act_out
=
act
(
tmp
);
dest_vec
[
ii
]
=
act_out
*
static_cast
<
T
>
(
mask_vec
[
ii
])
*
factor
;
}
// store result to global
platform
::
Store
<
T
,
VecSize
>
(
dest_vec
,
&
dst
[
r
*
cols
+
i
]);
if
(
!
is_test
)
{
platform
::
Store
<
MaskType
,
VecSize
>
(
mask_vec
,
&
mask
[
r
*
cols
+
i
]);
}
}
}
}
/**
* @brief dst = dropout(activation(src + bias));
*/
template
<
typename
T
,
typename
MaskType
,
typename
Functor
>
void
LaunchDropoutActBias
(
Functor
act_functor
,
const
uint64_t
seed
,
const
uint32_t
rows
,
const
uint32_t
cols
,
const
int
increment
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
T
*
src
,
const
T
*
bias
,
T
*
dst
,
MaskType
*
mask_data
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
// dropout_prob == 1.0f
if
(
std
::
abs
(
dropout_prob
-
1.0
f
)
<
1e-5
)
{
SetZero
<
T
>
(
ctx
,
dst
,
rows
*
cols
);
SetZero
<
MaskType
>
(
ctx
,
mask_data
,
rows
*
cols
);
return
;
}
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
const
int
real_vec_size
=
cols
%
VecSize
==
0
?
VecSize
:
1
;
const
auto
config
=
Get1DBlocksAnd2DGrids
(
ctx
,
rows
,
cols
,
real_vec_size
);
if
(
cols
%
VecSize
==
0
)
{
FusedDropoutActBias
<
T
,
MaskType
,
VecSize
,
Functor
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
seed
,
rows
,
cols
,
increment
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
src
,
bias
,
dst
,
mask_data
);
}
else
{
FusedDropoutActBias
<
T
,
MaskType
,
1
,
Functor
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
seed
,
rows
,
cols
,
increment
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
src
,
bias
,
dst
,
mask_data
);
}
}
/*
* @brief calculate the grad of no bias
*/
template
<
typename
T
,
typename
MaskType
,
int
VecSize
,
typename
Functor
>
__global__
void
FusedDropoutActGrad
(
Functor
act_grad
,
const
T
*
dout
,
const
MaskType
*
mask
,
const
T
*
src
,
const
T
factor
,
const
int64_t
size
,
T
*
dx
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
for
(
int
i
=
idx
*
VecSize
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
LoadT
dout_vec
;
LoadT
src_vec
;
MaskLoadT
mask_vec
;
platform
::
Load
<
T
,
VecSize
>
(
&
dout
[
i
],
&
dout_vec
);
platform
::
Load
<
MaskType
,
VecSize
>
(
&
mask
[
i
],
&
mask_vec
);
platform
::
Load
<
T
,
VecSize
>
(
&
src
[
i
],
&
src_vec
);
StoreT
dx_vec
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
T
args
[
2
];
args
[
0
]
=
dout_vec
[
ii
]
*
static_cast
<
T
>
(
mask_vec
[
ii
])
*
factor
;
args
[
1
]
=
src_vec
[
ii
];
dx_vec
[
ii
]
=
args
[
0
]
*
act_grad
.
UseOut
(
args
[
1
]);
}
platform
::
Store
<
T
,
VecSize
>
(
dx_vec
,
&
dx
[
i
]);
}
}
/**
* blocks(128 * 8)
* 1. calculate the dx and reduce total rows to 128 rows
* 2. save 128*8 temporary sum in 8*128 shared memory
* 3. reduce the sum of 128 cols data by 8*VecSize warps
*/
template
<
typename
T
,
typename
MaskType
,
int
BlockSizeX
,
int
BlockSizeY
,
int
VecSize
,
typename
Functor
>
__global__
void
FusedDropoutActBiasGrad
(
Functor
act_grad
,
const
T
*
dout
,
const
MaskType
*
mask
,
const
T
*
src
,
const
T
*
bias
,
const
T
factor
,
const
int64_t
rows
,
const
int64_t
cols
,
T
*
dx
,
T
*
dbias
)
{
int64_t
col_id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
T
tmp_sum
[
VecSize
]
=
{
static_cast
<
T
>
(
0
)};
// calculate the dx and temporary sum
if
(
col_id
*
VecSize
<
cols
)
{
for
(
int
row_id
=
threadIdx
.
y
;
row_id
<
rows
;
row_id
+=
blockDim
.
y
)
{
int
index
=
row_id
*
cols
+
col_id
*
VecSize
;
LoadT
dout_vec
;
LoadT
src_vec
;
LoadT
bias_vec
;
MaskLoadT
mask_vec
;
platform
::
Load
<
T
,
VecSize
>
(
&
dout
[
index
],
&
dout_vec
);
platform
::
Load
<
T
,
VecSize
>
(
&
src
[
index
],
&
src_vec
);
platform
::
Load
<
MaskType
,
VecSize
>
(
&
mask
[
index
],
&
mask_vec
);
platform
::
Load
<
T
,
VecSize
>
(
&
bias
[
col_id
*
VecSize
],
&
bias_vec
);
StoreT
dx_vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
T
val
;
T
args
[
2
];
args
[
0
]
=
dout_vec
[
i
]
*
static_cast
<
T
>
(
mask_vec
[
i
])
*
factor
;
args
[
1
]
=
src_vec
[
i
]
+
bias_vec
[
i
];
val
=
args
[
0
]
*
act_grad
.
UseOut
(
args
[
1
]);
dx_vec
[
i
]
=
val
;
tmp_sum
[
i
]
+=
val
;
}
platform
::
Store
<
T
,
VecSize
>
(
dx_vec
,
&
dx
[
index
]);
}
}
CalculateDBias
<
T
,
VecSize
,
BlockSizeX
,
BlockSizeY
>
(
tmp_sum
,
dbias
,
cols
);
}
/**
* @brief to launch kernel FusedResidualDropoutBiasGradVec
*/
template
<
typename
T
,
typename
MaskType
,
typename
Functor
>
void
LaunchDropoutActBiasGrad
(
Functor
act_functor
,
const
T
*
dout
,
const
MaskType
*
mask
,
const
T
*
src
,
const
T
*
bias
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
uint32_t
rows
,
const
uint32_t
cols
,
T
*
dx
,
T
*
dbias
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
const
T
zero
=
static_cast
<
T
>
(
0.0
);
auto
factor
=
dropout_prob
==
static_cast
<
float
>
(
1.0
f
)
?
zero
:
static_cast
<
T
>
(
1.0
/
(
1.0
-
dropout_prob
));
if
(
!
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
}
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
int
real_vec_size
=
cols
%
VecSize
==
0
?
VecSize
:
1
;
if
(
dbias
!=
nullptr
)
{
const
auto
threads
=
8
;
const
auto
blocks
=
std
::
max
(
static_cast
<
uint32_t
>
(
1
),
(
cols
/
real_vec_size
+
threads
-
1
)
/
threads
);
dim3
block_dim
(
threads
,
128
,
1
);
dim3
grid_dim
(
blocks
,
1
,
1
);
if
(
cols
%
VecSize
==
0
)
{
FusedDropoutActBiasGrad
<
T
,
MaskType
,
8
,
128
,
VecSize
,
Functor
><<<
grid_dim
,
block_dim
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
dout
,
mask
,
src
,
bias
,
factor
,
rows
,
cols
,
dx
,
dbias
);
}
else
{
FusedDropoutActBiasGrad
<
T
,
MaskType
,
8
,
128
,
1
,
Functor
><<<
grid_dim
,
block_dim
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
dout
,
mask
,
src
,
bias
,
factor
,
rows
,
cols
,
dx
,
dbias
);
}
}
else
{
const
uint64_t
n
=
rows
*
cols
;
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
,
n
/
real_vec_size
);
if
(
n
%
VecSize
==
0
)
{
FusedDropoutActGrad
<
T
,
MaskType
,
VecSize
,
Functor
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
dout
,
mask
,
src
,
factor
,
n
,
dx
);
}
else
{
FusedDropoutActGrad
<
T
,
MaskType
,
1
,
Functor
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
act_functor
,
dout
,
mask
,
src
,
factor
,
n
,
dx
);
}
}
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu
0 → 100755
浏览文件 @
cee70434
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include <random>
#include <vector>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_dropout_test.h"
#include "paddle/fluid/operators/math/functors.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
details
=
paddle
::
operators
::
details
;
namespace
math
=
paddle
::
operators
::
math
;
/**
* @brief the unittest of fused_dropout_act_bias
* 1. random input data
* 2. add bias, call activation, call paddle dropout, and get the base result
* 3. call FusedDropoutActBias function get fused result
* 4. compare ther base result and fused result
*/
template
<
typename
T
,
typename
Functor
,
typename
GradFunctor
>
struct
TestFusedDropoutActBias
{
uint32_t
rows
;
uint32_t
cols
;
uint64_t
seed
;
float
dropout_prob
;
bool
is_upscale_in_train
;
bool
is_test
;
// default false, Set to true for inference only
bool
has_bias
=
true
;
framework
::
Tensor
src
,
bias
,
out
,
mask
;
framework
::
Tensor
dsrc
,
dbias
;
std
::
vector
<
T
>
src_vec
,
bias_vec
,
out_vec
,
mask_vec
;
std
::
vector
<
T
>
correct_out
,
correct_dsrc
,
correct_dbias
;
std
::
vector
<
uint8_t
>
correct_mask
;
platform
::
CUDAPlace
place
;
platform
::
CUDADeviceContext
*
ctx
;
TestFusedDropoutActBias
()
{
rows
=
32
;
cols
=
32
;
seed
=
0
;
dropout_prob
=
0.0
;
is_upscale_in_train
=
false
;
is_test
=
false
;
has_bias
=
true
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
devicectx
=
pool
.
Get
(
place
);
ctx
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
devicectx
);
}
TestFusedDropoutActBias
(
int
rows_
,
int
cols_
,
uint64_t
seed_
=
0
,
float
dropout_prob_
=
0.0
,
bool
is_upscale_in_train_
=
false
,
bool
is_test_
=
false
)
{
rows
=
rows_
;
cols
=
cols_
;
seed
=
seed_
;
dropout_prob
=
dropout_prob_
;
is_upscale_in_train
=
is_upscale_in_train_
;
is_test
=
is_test_
;
has_bias
=
true
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
devicectx
=
pool
.
Get
(
place
);
ctx
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
devicectx
);
}
~
TestFusedDropoutActBias
()
{}
void
SetUp
()
{
const
int
n
=
rows
*
cols
;
correct_out
.
resize
(
n
);
correct_mask
.
resize
(
n
);
correct_dsrc
.
resize
(
n
);
correct_dbias
.
resize
(
cols
);
src_vec
.
resize
(
n
);
bias_vec
.
resize
(
cols
);
std
::
default_random_engine
random
(
time
(
NULL
));
std
::
uniform_real_distribution
<
float
>
dis
(
0.0
,
1.0
);
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
src_vec
[
i
*
cols
+
j
]
=
static_cast
<
T
>
(
dis
(
random
));
if
(
i
==
0
)
bias_vec
[
j
]
=
dis
(
random
);
}
}
framework
::
TensorFromVector
<
T
>
(
src_vec
,
*
ctx
,
&
src
);
src
.
Resize
({
rows
,
cols
});
if
(
has_bias
)
{
framework
::
TensorFromVector
<
T
>
(
bias_vec
,
*
ctx
,
&
bias
);
bias
.
Resize
({
cols
});
}
{
out
.
mutable_data
<
T
>
({
rows
,
cols
},
place
);
mask
.
mutable_data
<
uint8_t
>
({
rows
,
cols
},
place
);
dsrc
.
mutable_data
<
T
>
({
rows
,
cols
},
place
);
if
(
has_bias
)
{
dbias
.
mutable_data
<
T
>
({
cols
},
place
);
}
}
}
void
BaseForward
()
{
std
::
vector
<
T
>
out1
(
rows
*
cols
);
Functor
act
;
if
(
has_bias
)
{
// add bias and call activation
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
const
T
tmp
=
src_vec
[
i
*
cols
+
j
]
+
bias_vec
[
j
];
out1
[
i
*
cols
+
j
]
=
act
(
tmp
);
}
}
// call dropout
Dropout
<
T
>
(
out1
,
src
.
dims
(),
&
correct_out
,
&
correct_mask
,
*
ctx
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
);
}
else
{
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
const
T
tmp
=
src_vec
[
i
*
cols
+
j
];
out1
[
i
*
cols
+
j
]
=
act
(
tmp
);
}
}
Dropout
<
T
>
(
out1
,
src
.
dims
(),
&
correct_out
,
&
correct_mask
,
*
ctx
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
);
}
ctx
->
Wait
();
}
void
BaseBackward
()
{
std
::
vector
<
T
>
_out
(
rows
*
cols
);
// call dropout_grad
DropoutGrad
<
T
>
(
&
_out
,
src
.
dims
(),
correct_out
,
correct_mask
,
*
ctx
,
dropout_prob
,
is_upscale_in_train
);
// calculate dbias
memset
(
&
correct_dbias
[
0
],
0
,
cols
*
sizeof
(
T
));
GradFunctor
act_grad
;
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
T
args
[
2
];
args
[
0
]
=
_out
[
i
*
cols
+
j
];
if
(
has_bias
)
{
args
[
1
]
=
src_vec
[
i
*
cols
+
j
]
+
bias_vec
[
j
];
}
else
{
args
[
1
]
=
src_vec
[
i
*
cols
+
j
];
}
T
val
=
args
[
0
]
*
act_grad
.
UseOut
(
args
[
1
]);
correct_dsrc
[
i
*
cols
+
j
]
=
val
;
}
}
if
(
has_bias
)
{
// reduce_sum: keep the same calculate order as the GPU
ReduceSum
<
T
>
(
correct_dsrc
,
&
correct_dbias
,
rows
,
cols
);
}
}
void
FusedForward
()
{
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
auto
config
=
paddle
::
operators
::
Get1DBlocksAnd2DGrids
(
*
ctx
,
static_cast
<
uint64_t
>
(
rows
),
static_cast
<
uint64_t
>
(
cols
),
VecSize
);
const
int
increment
=
((
cols
-
1
)
/
(
config
.
thread_per_block
.
x
*
config
.
block_per_grid
.
x
*
VecSize
)
+
1
)
*
VecSize
;
T
*
bias_ptr
=
nullptr
;
if
(
has_bias
)
{
bias_ptr
=
bias
.
data
<
T
>
();
}
Functor
act
;
paddle
::
operators
::
LaunchDropoutActBias
<
T
,
uint8_t
,
Functor
>
(
act
,
seed
,
rows
,
cols
,
increment
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
src
.
data
<
T
>
(),
bias_ptr
,
out
.
data
<
T
>
(),
mask
.
data
<
uint8_t
>
(),
*
ctx
);
ctx
->
Wait
();
}
void
FusedBackward
()
{
if
(
is_test
)
return
;
T
*
bias_ptr
=
nullptr
;
T
*
dbias_ptr
=
nullptr
;
if
(
has_bias
)
{
dbias_ptr
=
dbias
.
data
<
T
>
();
bias_ptr
=
bias
.
data
<
T
>
();
}
GradFunctor
act_grad
;
paddle
::
operators
::
LaunchDropoutActBiasGrad
<
T
,
uint8_t
,
GradFunctor
>
(
act_grad
,
out
.
data
<
T
>
(),
mask
.
data
<
uint8_t
>
(),
src
.
data
<
T
>
(),
bias_ptr
,
dropout_prob
,
is_upscale_in_train
,
rows
,
cols
,
dsrc
.
data
<
T
>
(),
dbias_ptr
,
*
ctx
);
}
void
Run
()
{
SetUp
();
BaseForward
();
FusedForward
();
BaseBackward
();
FusedBackward
();
}
void
CheckOut
(
const
T
diff
)
{
const
int
n
=
rows
*
cols
;
std
::
vector
<
T
>
_out
(
n
);
std
::
vector
<
uint8_t
>
_mask
(
n
);
framework
::
TensorToVector
(
out
,
*
ctx
,
&
_out
);
if
(
!
is_test
)
{
framework
::
TensorToVector
<
uint8_t
>
(
mask
,
*
ctx
,
&
_mask
);
}
ctx
->
Wait
();
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
EXPECT_LT
(
std
::
abs
(
_out
[
i
]
-
correct_out
[
i
]),
diff
);
if
(
!
is_test
)
EXPECT_EQ
(
_mask
[
i
],
correct_mask
[
i
]);
}
}
void
CheckGrad
(
const
T
diff
)
{
if
(
is_test
)
return
;
const
int
n
=
rows
*
cols
;
std
::
vector
<
T
>
_dsrc
(
n
);
framework
::
TensorToVector
(
dsrc
,
*
ctx
,
&
_dsrc
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
EXPECT_LT
(
std
::
abs
(
_dsrc
[
i
]
-
correct_dsrc
[
i
]),
diff
);
}
if
(
has_bias
)
{
std
::
vector
<
T
>
_dbias
(
cols
);
framework
::
TensorToVector
(
dbias
,
*
ctx
,
&
_dbias
);
ctx
->
Wait
();
for
(
int
i
=
0
;
i
<
cols
;
i
++
)
{
EXPECT_LT
(
std
::
abs
(
_dbias
[
i
]
-
correct_dbias
[
i
]),
diff
);
}
}
}
};
// test the shape , bias, activation
template
<
typename
T
,
typename
Functor
,
typename
GradFunctor
>
static
void
BaseTest
(
const
bool
is_fp16
=
false
)
{
const
int
rows
=
16
;
std
::
vector
<
int
>
cols_list
=
{
16
,
17
};
bool
has_bias
[
2
]
=
{
true
,
false
};
T
default_diff
=
!
is_fp16
?
static_cast
<
T
>
(
1e-5
)
:
static_cast
<
T
>
(
1e-1
);
for
(
auto
cols
:
{
16
,
17
})
{
for
(
auto
has_bias
:
{
true
,
false
})
{
TestFusedDropoutActBias
<
T
,
Functor
,
GradFunctor
>
test
(
rows
,
cols
);
test
.
has_bias
=
has_bias
;
test
.
Run
();
test
.
CheckOut
(
default_diff
);
test
.
CheckGrad
(
default_diff
);
}
}
}
TEST
(
FusedDropout
,
GPUFusedDorpoutActBias
)
{
BaseTest
<
float
,
math
::
ReluFunctor
<
float
>
,
math
::
ReluGradFunctor
<
float
>>
();
BaseTest
<
float
,
paddle
::
operators
::
GeluFunctor
<
float
>
,
paddle
::
operators
::
GeluGradFunctor
<
float
>>
();
}
TEST
(
FusedDropout
,
GPUFusedDropoutActBiasDouble
)
{
BaseTest
<
double
,
math
::
ReluFunctor
<
double
>
,
math
::
ReluGradFunctor
<
double
>>
();
BaseTest
<
double
,
paddle
::
operators
::
GeluFunctor
<
double
>
,
paddle
::
operators
::
GeluGradFunctor
<
double
>>
();
}
// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py
TEST
(
FusedDropout
,
GPUFusedDropoutActBiasFp16
)
{
using
fp16
=
platform
::
float16
;
BaseTest
<
fp16
,
math
::
ReluFunctor
<
fp16
>
,
math
::
ReluGradFunctor
<
fp16
>>
(
true
);
}
TEST
(
FusedDropout
,
GPUFusedDropoutActBiasIsUpscaleInTrain
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
for
(
auto
is_upscale_in_train
:
{
true
,
false
})
{
TestFusedDropoutActBias
<
float
,
math
::
ReluFunctor
<
float
>
,
math
::
ReluGradFunctor
<
float
>>
test
(
rows
,
cols
,
0
,
1.0
,
is_upscale_in_train
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-3
));
}
}
TEST
(
FusedDropout
,
GPUFusedDropoutActBiasIsTest
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedDropoutActBias
<
float
,
math
::
ReluFunctor
<
float
>
,
math
::
ReluGradFunctor
<
float
>>
test
(
rows
,
cols
,
0
,
0.35
,
true
,
true
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-3
));
}
TEST
(
FusedDropout
,
GPUFusedDropoutActBiasSeed
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedDropoutActBias
<
float
,
math
::
ReluFunctor
<
float
>
,
math
::
ReluGradFunctor
<
float
>>
test
(
rows
,
cols
,
125
,
0.0
,
false
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-3
));
}
TEST
(
FusedDropout
,
GPUFusedDropoutActBiasLargeShape
)
{
const
int
rows
=
256
;
const
int
cols
=
4096
;
TestFusedDropoutActBias
<
float
,
math
::
ReluFunctor
<
float
>
,
math
::
ReluGradFunctor
<
float
>>
test
(
rows
,
cols
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-3
));
}
paddle/fluid/operators/fused/fused_dropout_common.h
浏览文件 @
cee70434
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -39,8 +40,8 @@ namespace operators {
*/
inline
platform
::
GpuLaunchConfig
Get1DBlocksAnd2DGrids
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
uint32_t
rows
,
const
uint32_t
cols
,
const
int
VecS
ize
)
{
const
uint32_t
tmp_cols
=
cols
/
VecS
ize
;
const
uint32_t
cols
,
const
int
vec_s
ize
)
{
const
uint32_t
tmp_cols
=
cols
/
vec_s
ize
;
int
threads
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
std
::
min
(
tmp_cols
,
static_cast
<
uint32_t
>
(
ctx
.
GetMaxThreadsPerBlock
())));
...
...
@@ -54,19 +55,26 @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
return
config
;
}
__forceinline__
__device__
void
Rand1
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
template
<
int
VecSize
>
__forceinline__
__device__
void
RandVec
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
);
template
<
>
__forceinline__
__device__
void
RandVec
<
1
>
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
data
[
0
]
=
curand_uniform
(
state
);
}
__forceinline__
__device__
void
Rand2
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
template
<
>
__forceinline__
__device__
void
RandVec
<
2
>
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
data
[
0
]
=
curand_uniform
(
state
);
data
[
1
]
=
curand_uniform
(
state
);
}
__forceinline__
__device__
void
Rand4
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
template
<
>
__forceinline__
__device__
void
RandVec
<
4
>
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
float4
rand4
=
curand_uniform4
(
state
);
data
[
0
]
=
rand4
.
x
;
data
[
1
]
=
rand4
.
y
;
...
...
@@ -74,24 +82,54 @@ __forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state,
data
[
3
]
=
rand4
.
z
;
}
__forceinline__
__device__
void
Rand8
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
Rand4
(
state
,
data
);
Rand4
(
state
,
data
+
4
);
template
<
>
__forceinline__
__device__
void
RandVec
<
8
>
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
)
{
RandVec
<
4
>
(
state
,
data
);
RandVec
<
4
>
(
state
,
data
+
4
);
}
__forceinline__
__device__
void
RandVec
(
curandStatePhilox4_32_10_t
*
state
,
float
*
data
,
const
int
VecSize
)
{
if
(
VecSize
==
1
)
{
Rand1
(
state
,
data
);
}
else
if
(
VecSize
==
2
)
{
Rand2
(
state
,
data
);
}
else
if
(
VecSize
==
4
)
{
Rand4
(
state
,
data
);
}
else
if
(
VecSize
==
8
)
{
Rand8
(
state
,
data
);
}
else
{
return
;
template
<
typename
T
>
inline
void
SetZero
(
const
platform
::
CUDADeviceContext
&
ctx
,
T
*
ptr
,
const
size_t
size
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemsetAsync
(
ptr
,
0
,
size
*
sizeof
(
T
),
ctx
.
stream
()));
}
/**
* reduce the sum of 128 cols data by 8*VecSize warps
*/
template
<
typename
T
,
int
VecSize
,
int
BlockSizeX
,
int
BlockSizeY
>
inline
__device__
void
CalculateDBias
(
const
T
*
tmp_sum
,
T
*
dbias
,
const
int
cols
)
{
// save temporary sum to cache and do transpose
__shared__
T
cache
[
BlockSizeX
*
VecSize
][
BlockSizeY
];
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
cache
[
threadIdx
.
x
*
VecSize
+
i
][
threadIdx
.
y
]
=
tmp_sum
[
i
];
}
__syncthreads
();
// reduce sum
T
sum
=
static_cast
<
T
>
(
0
);
int
tid
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
x
=
tid
>>
5
;
// warp id
int
y
=
tid
&
31
;
// thread id on warp 0~31
// need BlockSizeX * VecSize warps
if
(
x
<
BlockSizeX
*
VecSize
)
{
// reduce 128 to 32
#pragma unroll
for
(
int
i
=
0
;
i
<
(
BlockSizeY
>>
5
);
i
++
)
{
sum
+=
cache
[
x
][
y
+
i
*
32
];
}
}
// reduce 32 to 1
sum
=
WarpReduceSum
(
sum
);
// save sum to dbias
int
bias_id
=
blockIdx
.
x
*
blockDim
.
x
*
VecSize
+
x
;
if
(
y
==
0
&&
x
<
VecSize
*
BlockSizeX
&&
bias_id
<
cols
)
{
dbias
[
bias_id
]
=
sum
;
}
}
...
...
paddle/fluid/operators/fused/fused_dropout_test.h
浏览文件 @
cee70434
...
...
@@ -115,3 +115,22 @@ void DropoutGrad(std::vector<T> *dx, const framework::DDim &x_dim,
framework
::
TensorToVector
(
*
tensor_dx
,
ctx
,
dx
);
ctx
.
Wait
();
}
template
<
typename
T
>
inline
void
ReduceSum
(
const
std
::
vector
<
T
>
&
dout
,
std
::
vector
<
T
>
*
dbias
,
const
int
rows
,
const
int
cols
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
std
::
vector
<
T
>
tmp_dbias
(
rows
);
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
tmp_dbias
[
i
]
=
dout
[
i
*
cols
+
j
];
}
int
tmp_rows
=
rows
/
2
;
while
(
tmp_rows
)
{
for
(
int
i
=
0
;
i
<
tmp_rows
;
i
++
)
{
tmp_dbias
[
i
]
+=
tmp_dbias
[
i
+
tmp_rows
];
}
tmp_rows
/=
2
;
}
(
*
dbias
)[
j
]
=
tmp_dbias
[
0
];
}
}
paddle/fluid/operators/fused/fused_residual_dropout_bias.h
浏览文件 @
cee70434
...
...
@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -28,8 +27,9 @@ template <typename T, typename MaskType, int VecSize, bool ComputeLayerNorm>
__forceinline__
__device__
void
FusedResidualDropoutBiasOneThread
(
const
int
row_id
,
const
int
col_id
,
const
int
cols
,
curandStatePhilox4_32_10_t
*
state
,
const
float
dropout_prob
,
const
T
factor
,
const
T
*
src
,
const
T
*
residual
,
const
T
*
bias
,
T
*
dst
,
MaskType
*
mask
,
const
bool
is_test
,
typename
details
::
MPTypeTrait
<
T
>::
Type
*
mean_val
,
const
T
*
__restrict__
src
,
const
T
*
__restrict__
residual
,
const
T
*
__restrict__
bias
,
T
*
dst
,
MaskType
*
mask
,
const
bool
is_test
,
typename
details
::
MPTypeTrait
<
T
>::
Type
*
mean_val
,
typename
details
::
MPTypeTrait
<
T
>::
Type
*
var_val
)
{
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
...
...
@@ -54,7 +54,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
MaskStoreT
mask_vec
;
if
(
!
is_test
)
{
float
rand
[
VecSize
];
RandVec
(
state
,
rand
,
VecSize
);
RandVec
<
VecSize
>
(
state
,
rand
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
mask_vec
[
ii
]
=
static_cast
<
MaskType
>
(
rand
[
ii
]
>=
dropout_prob
);
...
...
@@ -97,24 +97,21 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
template
<
typename
T
,
typename
MaskType
,
int
VecSize
>
__global__
void
FusedResidualDropoutBias
(
const
size_t
rows
,
const
size_t
cols
,
uint64_t
seed
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
T
*
src
,
const
T
*
residual
,
const
T
*
bias
,
MaskType
*
mask
,
T
*
dst
,
uint64_t
increment
,
const
bool
is_test
)
{
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
T
*
__restrict__
src
,
const
T
*
__restrict__
residual
,
const
T
*
__restrict__
bias
,
MaskType
*
mask
,
T
*
dst
,
uint64_t
increment
,
const
bool
is_test
)
{
int
col_id
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
row_id
=
blockIdx
.
y
;
int
idx
=
row_id
*
cols
+
col_id
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
increment
,
&
state
);
T
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
if
(
!
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
}
T
factor
=
is_upscale_in_train
?
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
))
:
static_cast
<
T
>
(
1.0
f
);
if
(
is_test
)
{
factor
=
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
if
(
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
}
factor
=
is_upscale_in_train
?
static_cast
<
T
>
(
1.0
f
)
:
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
for
(
int
r
=
row_id
;
r
<
rows
;
r
+=
blockDim
.
y
*
gridDim
.
y
)
{
for
(
int
i
=
col_id
*
VecSize
;
i
<
cols
;
...
...
@@ -144,8 +141,7 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
memory
::
Copy
(
cuda_place
,
dst
,
cuda_place
,
residual
,
rows
*
cols
*
sizeof
(
T
),
ctx
.
stream
());
if
(
!
is_test
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemsetAsync
(
mask_data
,
0
,
rows
*
cols
*
sizeof
(
MaskType
),
ctx
.
stream
()));
SetZero
<
MaskType
>
(
ctx
,
mask_data
,
rows
*
cols
);
}
return
;
}
...
...
@@ -234,36 +230,7 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
}
}
// save temporary sum to cache and do transpose
__shared__
T
cache
[
BlockSizeX
*
VecSize
][
BlockSizeY
];
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
cache
[
threadIdx
.
x
*
VecSize
+
i
][
threadIdx
.
y
]
=
tmp_sum
[
i
];
}
__syncthreads
();
// reduce sum
T
sum
=
static_cast
<
T
>
(
0
);
int
tid
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
x
=
tid
>>
5
;
// warp id
int
y
=
tid
&
31
;
// thread id on warp 0~31
// need BlockSizeX * VecSize warps
if
(
x
<
BlockSizeX
*
VecSize
)
{
// reduce 128 to 32
#pragma unroll
for
(
int
i
=
0
;
i
<
(
BlockSizeY
>>
5
);
i
++
)
{
sum
+=
cache
[
x
][
y
+
i
*
32
];
}
}
// reduce 32 to 1
sum
=
WarpReduceSum
(
sum
);
// save sum to dbias
int
bias_id
=
blockIdx
.
x
*
blockDim
.
x
*
VecSize
+
x
;
if
(
y
==
0
&&
x
<
VecSize
*
BlockSizeX
&&
bias_id
<
cols
)
{
dbias
[
bias_id
]
=
sum
;
}
CalculateDBias
<
T
,
VecSize
,
BlockSizeX
,
BlockSizeY
>
(
tmp_sum
,
dbias
,
cols
);
}
/**
...
...
@@ -287,9 +254,9 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask,
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
int
real_vec_size
=
cols
%
VecSize
==
0
?
VecSize
:
1
;
if
(
dbias
!=
nullptr
)
{
auto
threads
=
std
::
min
(
cols
/
real_vec_size
,
static_cast
<
uint32_t
>
(
8
))
;
auto
blocks
=
std
::
max
((
uint32_t
)
1
,
(
cols
/
real_vec_size
+
threads
-
1
)
/
threads
);
const
auto
threads
=
8
;
auto
blocks
=
std
::
max
(
static_cast
<
uint32_t
>
(
1
),
(
cols
/
real_vec_size
+
threads
-
1
)
/
threads
);
dim3
block_dim
(
threads
,
128
,
1
);
dim3
grid_dim
(
blocks
,
1
,
1
);
if
(
cols
%
VecSize
==
0
)
{
...
...
paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu
浏览文件 @
cee70434
...
...
@@ -114,16 +114,12 @@ struct TestFusedResidualDropoutBias {
}
{
out
.
Resize
({
rows
,
cols
});
out
.
mutable_data
<
T
>
(
place
);
mask
.
Resize
({
rows
,
cols
});
mask
.
mutable_data
<
uint8_t
>
(
place
);
dsrc
.
Resize
({
rows
,
cols
});
dsrc
.
mutable_data
<
T
>
(
place
);
out
.
mutable_data
<
T
>
({
rows
,
cols
},
place
);
mask
.
mutable_data
<
uint8_t
>
({
rows
,
cols
},
place
);
dsrc
.
mutable_data
<
T
>
({
rows
,
cols
},
place
);
if
(
has_bias
)
{
dbias
.
Resize
({
cols
});
dbias
.
mutable_data
<
T
>
(
place
);
dbias
.
mutable_data
<
T
>
({
cols
},
place
);
}
}
}
...
...
@@ -159,17 +155,16 @@ struct TestFusedResidualDropoutBias {
dropout_prob
,
is_upscale_in_train
);
// calc dbias
memset
(
&
correct_dbias
[
0
],
0
,
cols
*
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
int
j
=
0
;
j
<
cols
;
j
++
)
{
correct_dbias
[
j
]
+=
correct_out
[
i
*
cols
+
j
];
}
if
(
has_bias
)
{
ReduceSum
<
T
>
(
correct_out
,
&
correct_dbias
,
rows
,
cols
);
}
}
void
FusedForward
()
{
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
auto
config
=
paddle
::
operators
::
Get1DBlocksAnd2DGrids
(
*
ctx
,
(
uint64_t
)
rows
,
(
uint64_t
)
cols
,
VecSize
);
*
ctx
,
static_cast
<
uint64_t
>
(
rows
),
static_cast
<
uint64_t
>
(
cols
),
VecSize
);
const
int
increment
=
((
cols
-
1
)
/
(
config
.
thread_per_block
.
x
*
config
.
block_per_grid
.
x
*
VecSize
)
+
1
)
*
...
...
@@ -253,21 +248,14 @@ struct TestFusedResidualDropoutBias {
template
<
typename
T
>
static
void
BaseTest
(
const
bool
is_fp16
=
false
)
{
const
int
rows
=
16
;
std
::
vector
<
int
>
cols_list
=
{
16
,
17
};
bool
has_bias
[
2
]
=
{
true
,
false
};
T
default_diff
=
static_cast
<
T
>
(
1e-5
);
if
(
is_fp16
)
{
default_diff
=
static_cast
<
T
>
(
1e-2
);
}
for
(
int
i
=
0
;
i
<
cols_list
.
size
();
i
++
)
{
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
TestFusedResidualDropoutBias
<
T
>
test
(
rows
,
cols_list
[
i
]);
test
.
has_bias
=
has_bias
[
j
];
T
default_diff
=
!
is_fp16
?
static_cast
<
T
>
(
1e-5
)
:
static_cast
<
T
>
(
1e-1
);
for
(
auto
cols
:
{
16
,
17
})
{
for
(
auto
has_bias
:
{
true
,
false
})
{
TestFusedResidualDropoutBias
<
T
>
test
(
rows
,
cols
);
test
.
has_bias
=
has_bias
;
test
.
Run
();
test
.
CheckOut
(
default_diff
);
if
(
!
is_fp16
)
{
test
.
CheckGrad
(
default_diff
);
}
test
.
CheckGrad
(
default_diff
);
}
}
}
...
...
@@ -276,30 +264,23 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest<float>(); }
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBiasDouble
)
{
BaseTest
<
double
>
();
}
// test fp16, For inference, check_grad is not required. ref: testdropout_op.py
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBiasFp16
)
{
BaseTest
<
platform
::
float16
>
(
true
);
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias
2
)
{
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias
IsUpscaleInTrain
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
0
,
1.0
,
false
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias3
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
0
,
1.0
,
true
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
for
(
auto
is_upscale_in_train
:
{
true
,
false
})
{
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
0
,
1.0
,
is_upscale_in_train
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-5
));
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias
4
)
{
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias
IsTest
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
0
,
0.35
,
true
,
true
);
...
...
@@ -308,7 +289,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias4) {
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias
5
)
{
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias
Seed
)
{
const
int
rows
=
16
;
const
int
cols
=
16
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
,
125
,
0.0
,
false
,
false
);
...
...
@@ -317,8 +298,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias5) {
test
.
CheckGrad
(
static_cast
<
float
>
(
1e-5
));
}
// test large shape
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBias6
)
{
TEST
(
FusedDropout
,
GPUFusedResidualDropoutBiasLargeShape
)
{
const
int
rows
=
256
;
const
int
cols
=
4096
;
TestFusedResidualDropoutBias
<
float
>
test
(
rows
,
cols
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录