Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c171eca2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c171eca2
编写于
9月 03, 2021
作者:
Y
Yiqun Liu
提交者:
GitHub
9月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify the implementation of AlignedVector and simplify the codes of dropout and cast. (#35373)
上级
a9dfebb9
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
143 addition
and
172 deletion
+143
-172
paddle/fluid/operators/cast_op.cu
paddle/fluid/operators/cast_op.cu
+11
-27
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+30
-40
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+13
-31
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
...fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+3
-3
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+4
-4
paddle/fluid/operators/fused/attn_bias_add.cu.h
paddle/fluid/operators/fused/attn_bias_add.cu.h
+3
-26
paddle/fluid/platform/aligned_vector.h
paddle/fluid/platform/aligned_vector.h
+77
-0
paddle/fluid/platform/fast_divmod.h
paddle/fluid/platform/fast_divmod.h
+2
-41
未找到文件。
paddle/fluid/operators/cast_op.cu
浏览文件 @
c171eca2
...
...
@@ -13,47 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace
paddle
{
namespace
operators
{
// aligned vector generates vectorized load/store on CUDA
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
template
<
typename
InT
,
typename
OutT
,
int
VecSize
>
__global__
void
VecCastCUDAKernel
(
const
InT
*
in
,
const
int64_t
N
,
OutT
*
out
)
{
using
LoadT
=
platform
::
AlignedVector
<
InT
,
VecSize
>
;
using
StoreT
=
platform
::
AlignedVector
<
OutT
,
VecSize
>
;
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
LoadT
=
AlignedVector
<
InT
,
VecSize
>
;
using
StoreT
=
AlignedVector
<
OutT
,
VecSize
>
;
for
(
int64_t
i
=
idx
*
VecSize
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
InT
in_vec
[
VecSize
];
LoadT
*
in_value
=
reinterpret_cast
<
LoadT
*>
(
&
in_vec
);
*
in_value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
in
[
i
]);
LoadT
in_val
;
platform
::
Load
<
InT
,
VecSize
>
(
&
in
[
i
],
&
in_val
);
OutT
out_vec
[
VecSize
]
;
StoreT
out_val
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
out_v
ec
[
ii
]
=
static_cast
<
OutT
>
(
in_vec
[
ii
]);
for
(
int
j
=
0
;
j
<
VecSize
;
j
++
)
{
out_v
al
[
j
]
=
static_cast
<
OutT
>
(
in_val
[
j
]);
}
*
(
reinterpret_cast
<
StoreT
*>
(
&
out
[
i
]))
=
*
reinterpret_cast
<
StoreT
*>
(
&
out_vec
[
0
]);
platform
::
Store
<
OutT
,
VecSize
>
(
out_val
,
&
out
[
i
]);
}
}
...
...
@@ -78,7 +62,7 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
auto
*
out
=
out_
->
mutable_data
<
OutT
>
(
ctx_
.
GetPlace
());
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx_
,
size
);
int
vec_size
=
VectorizedSize
<
OutT
>
(
out
);
int
vec_size
=
platform
::
Get
VectorizedSize
<
OutT
>
(
out
);
if
(
!
std
::
is_same
<
InT
,
OutT
>::
value
&&
vec_size
==
4
&&
size
%
4
==
0
)
{
VecCastCUDAKernel
<
InT
,
OutT
,
4
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx_
.
stream
()
>>>
(
...
...
paddle/fluid/operators/dropout_op.cu
浏览文件 @
c171eca2
...
...
@@ -38,7 +38,7 @@ namespace operators {
template
<
typename
T
,
typename
MaskType
>
__global__
void
RandomGenerator
(
const
size_t
n
,
uint64_t
seed
,
const
float
dropout_prob
,
const
T
*
src
,
MaskType
*
mask
_data
,
T
*
dst
,
MaskType
*
mask
,
T
*
dst
,
bool
is_upscale_in_train
,
uint64_t
increment
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
#ifdef PADDLE_WITH_HIP
...
...
@@ -49,36 +49,36 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed,
curand_init
(
seed
,
idx
,
increment
,
&
state
);
#endif
MaskType
mask
;
T
dest
;
MaskType
mask_val
;
T
dst_val
;
T
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
for
(;
idx
<
n
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
s
=
src
[
idx
];
T
s
rc_val
=
src
[
idx
];
#ifdef PADDLE_WITH_HIP
if
(
hiprand_uniform
(
&
state
)
<
dropout_prob
)
{
#else
if
(
curand_uniform
(
&
state
)
<
dropout_prob
)
{
#endif
mask
=
0
;
d
est
=
0
;
mask
_val
=
0
;
d
st_val
=
0
;
}
else
{
mask
=
1
;
if
(
is_upscale_in_train
)
{
dest
=
s
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
dest
=
s
;
}
mask_val
=
1
;
dst_val
=
is_upscale_in_train
?
src_val
*
factor
:
src_val
;
}
mask
_data
[
idx
]
=
mask
;
dst
[
idx
]
=
d
est
;
mask
[
idx
]
=
mask_val
;
dst
[
idx
]
=
d
st_val
;
}
}
template
<
typename
T
,
typename
MaskType
,
int
VecSize
>
__global__
void
VectorizedRandomGenerator
(
const
size_t
n
,
uint64_t
seed
,
const
float
dropout_prob
,
const
T
*
src
,
MaskType
*
mask
_data
,
T
*
dst
,
bool
is_upscale_in_train
,
const
T
*
src
,
MaskType
*
mask
,
T
*
dst
,
bool
is_upscale_in_train
,
uint64_t
increment
)
{
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
#ifdef PADDLE_WITH_HIP
int64_t
idx
=
hipBlockDim_x
*
hipBlockIdx_x
+
hipThreadIdx_x
;
hiprandStatePhilox4_32_10_t
state
;
...
...
@@ -89,43 +89,33 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
curand_init
(
seed
,
idx
,
increment
,
&
state
);
#endif
MaskType
mask
;
T
dest
;
using
LoadT
=
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
AlignedVector
<
MaskType
,
VecSize
>
;
T
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
for
(
int
i
=
idx
*
VecSize
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
T
src_vec
[
VecSize
]
;
LoadT
*
value
=
reinterpret_cast
<
LoadT
*>
(
&
src_vec
);
*
value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
src
[
i
]);
LoadT
src_val
;
platform
::
Load
<
T
,
VecSize
>
(
&
src
[
i
],
&
src_val
);
#ifdef PADDLE_WITH_HIP
float4
rand
=
hiprand_uniform4
(
&
state
);
#else
float4
rand
=
curand_uniform4
(
&
state
);
#endif
T
dest_vec
[
VecSize
]
;
Mask
Type
mask_vec
[
VecSize
]
;
LoadT
dst_val
;
Mask
LoadT
mask_val
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
if
((
&
rand
.
x
)[
ii
]
<
dropout_prob
)
{
d
est_vec
[
ii
]
=
0
;
mask_v
ec
[
ii
]
=
0
;
for
(
int
j
=
0
;
j
<
VecSize
;
j
++
)
{
if
((
&
rand
.
x
)[
j
]
<
dropout_prob
)
{
d
st_val
[
j
]
=
0
;
mask_v
al
[
j
]
=
0
;
}
else
{
if
(
is_upscale_in_train
)
{
dest_vec
[
ii
]
=
src_vec
[
ii
]
*
factor
;
}
else
{
dest_vec
[
ii
]
=
src_vec
[
ii
];
}
mask_vec
[
ii
]
=
1
;
dst_val
[
j
]
=
is_upscale_in_train
?
src_val
[
j
]
*
factor
:
src_val
[
j
];
mask_val
[
j
]
=
1
;
}
}
*
(
reinterpret_cast
<
LoadT
*>
(
&
dst
[
i
]))
=
*
reinterpret_cast
<
LoadT
*>
(
&
dest_vec
[
0
]);
*
(
reinterpret_cast
<
MaskLoadT
*>
(
&
mask_data
[
i
]))
=
*
reinterpret_cast
<
MaskLoadT
*>
(
&
mask_vec
[
0
]);
platform
::
Store
<
T
,
VecSize
>
(
dst_val
,
&
dst
[
i
]);
platform
::
Store
<
MaskType
,
VecSize
>
(
mask_val
,
&
mask
[
i
]);
}
}
...
...
@@ -185,7 +175,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
// same as the previous calls.
uint64_t
seed_data
;
uint64_t
increment
;
int
vec_size
=
VectorizedSize
<
T
>
(
x_data
);
int
vec_size
=
platform
::
Get
VectorizedSize
<
T
>
(
x_data
);
auto
offset
=
((
x_numel
-
1
)
/
(
config
.
block_per_grid
.
x
*
config
.
thread_per_block
.
x
*
vec_size
)
+
1
)
*
...
...
paddle/fluid/operators/dropout_op.h
浏览文件 @
c171eca2
...
...
@@ -21,54 +21,36 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace
paddle
{
namespace
operators
{
// aligned vector generates vectorized load/store on CUDA
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
#if defined(__NVCC__) || defined(__HIPCC__)
template
<
typename
T
,
typename
MaskType
,
int
VecSize
>
__global__
void
DropoutGradCUDAKernel
(
const
T
*
dout
,
const
MaskType
*
mask
,
const
T
factor
,
const
int64_t
size
,
T
*
dx
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
LoadT
=
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
AlignedVector
<
MaskType
,
VecSize
>
;
using
LoadT
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
MaskLoadT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(
int
i
=
idx
*
VecSize
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
T
dout_vec
[
VecSize
];
LoadT
*
dout_value
=
reinterpret_cast
<
LoadT
*>
(
&
dout_vec
);
*
dout_value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
dout
[
i
]);
LoadT
dout_val
;
platform
::
Load
<
T
,
VecSize
>
(
&
dout
[
i
],
&
dout_val
);
MaskType
mask_vec
[
VecSize
];
MaskLoadT
*
mask_value
=
reinterpret_cast
<
MaskLoadT
*>
(
&
mask_vec
);
*
mask_value
=
*
reinterpret_cast
<
const
MaskLoadT
*>
(
&
mask
[
i
]);
MaskLoadT
mask_val
;
platform
::
Load
<
MaskType
,
VecSize
>
(
&
mask
[
i
],
&
mask_val
);
T
dx_vec
[
VecSize
]
;
LoadT
dx_val
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
dx_v
ec
[
ii
]
=
dout_vec
[
ii
]
*
static_cast
<
T
>
(
mask_vec
[
ii
])
*
factor
;
for
(
int
j
=
0
;
j
<
VecSize
;
j
++
)
{
dx_v
al
[
j
]
=
dout_val
[
j
]
*
static_cast
<
T
>
(
mask_val
[
j
])
*
factor
;
}
*
(
reinterpret_cast
<
LoadT
*>
(
&
dx
[
i
]))
=
*
reinterpret_cast
<
LoadT
*>
(
&
dx_vec
[
0
]);
platform
::
Store
<
T
,
VecSize
>
(
dx_val
,
&
dx
[
i
]);
}
}
#endif
...
...
@@ -187,7 +169,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
if
(
dropout_prob
==
1.0
f
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
0
)
*
dY
;
}
else
{
int
vec_size
=
VectorizedSize
<
T
>
(
grad_y
->
data
<
T
>
());
int
vec_size
=
platform
::
Get
VectorizedSize
<
T
>
(
grad_y
->
data
<
T
>
());
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
())
&&
vec_size
==
4
&&
size
%
4
==
0
)
{
#if defined(__NVCC__) || defined(__HIPCC__)
...
...
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
浏览文件 @
c171eca2
...
...
@@ -199,8 +199,8 @@ struct StridesCalculation {
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
ElementwiseType
ET
,
int
VecSize
,
int
kDims
>
struct
BroadcastArgsWrapper
{
using
InVecType
=
platform
::
Cuda
AlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
platform
::
Cuda
AlignedVector
<
OutT
,
VecSize
>
;
using
InVecType
=
platform
::
AlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
platform
::
AlignedVector
<
OutT
,
VecSize
>
;
OutT
*
out_data
;
OutVecType
*
vec_out_data
;
...
...
@@ -320,7 +320,7 @@ template <typename InT, typename OutT, typename BroadcastArgsWrapper,
ElementwiseType
ET
,
int
VecSize
>
__device__
inline
void
VectorizedBroadcastKernelImpl
(
BroadcastArgsWrapper
broadcast_wrapper
,
int
tid
)
{
using
OutVecType
=
platform
::
Cuda
AlignedVector
<
OutT
,
VecSize
>
;
using
OutVecType
=
platform
::
AlignedVector
<
OutT
,
VecSize
>
;
OutVecType
args_out
;
InT
ins
[
ET
];
InT
args
[
ET
][
VecSize
];
...
...
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
浏览文件 @
c171eca2
...
...
@@ -69,8 +69,8 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
template
<
ElementwiseType
ET
,
int
VecSize
,
typename
InT
,
typename
OutT
>
struct
ElementwiseDataWrapper
{
using
InVecType
=
platform
::
Cuda
AlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
platform
::
Cuda
AlignedVector
<
OutT
,
VecSize
>
;
using
InVecType
=
platform
::
AlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
platform
::
AlignedVector
<
OutT
,
VecSize
>
;
const
InT
*
__restrict__
in_data
[
ET
];
OutT
*
out_data
;
...
...
@@ -117,8 +117,8 @@ template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
typename
InT
,
typename
OutT
,
typename
Functor
>
__device__
inline
void
VectorizedKernelImpl
(
ElementwiseWrapper
data
,
Functor
func
,
int
tid
)
{
using
InVecType
=
platform
::
Cuda
AlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
platform
::
Cuda
AlignedVector
<
OutT
,
VecSize
>
;
using
InVecType
=
platform
::
AlignedVector
<
InT
,
VecSize
>
;
using
OutVecType
=
platform
::
AlignedVector
<
OutT
,
VecSize
>
;
InVecType
ins_vec
[
ET
];
OutVecType
out_vec
;
InT
*
ins_ptr
[
ET
];
...
...
paddle/fluid/operators/fused/attn_bias_add.cu.h
浏览文件 @
c171eca2
...
...
@@ -96,36 +96,13 @@ __global__ void BroadcastKernelBinary(
kernel_primitives
::
WriteData
<
OutT
,
VecSize
,
1
,
1
>
(
out
+
fix
,
result
,
num
);
}
template
<
typename
T
>
int
GetVectorizedSizeImpl
(
const
T
*
pointer
)
{
constexpr
int
max_load_bits
=
128
;
int
valid_vec_size
=
max_load_bits
/
CHAR_BIT
/
sizeof
(
T
);
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec8
=
std
::
alignment_of
<
platform
::
CudaAlignedVector
<
T
,
8
>>::
value
;
// NOLINT
constexpr
int
vec4
=
std
::
alignment_of
<
platform
::
CudaAlignedVector
<
T
,
4
>>::
value
;
// NOLINT
constexpr
int
vec2
=
std
::
alignment_of
<
platform
::
CudaAlignedVector
<
T
,
2
>>::
value
;
// NOLINT
if
(
address
%
vec8
==
0
)
{
// Note: this line can change from 4 to 8 if it can improve the performance.
return
std
::
min
(
4
,
valid_vec_size
);
}
else
if
(
address
%
vec4
==
0
)
{
return
std
::
min
(
4
,
valid_vec_size
);
}
else
if
(
address
%
vec2
==
0
)
{
return
std
::
min
(
2
,
valid_vec_size
);
}
else
{
return
1
;
}
}
// bias add forward impl for "[m, n] + [n] = [m, n]"
template
<
typename
T
>
void
LaunchBiasAddFwKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
int
m
,
int
n
,
const
T
*
in0
,
const
T
*
in1
,
T
*
out
)
{
int
in_vec_size
=
std
::
min
(
GetVectorizedSizeImpl
<
T
>
(
in0
),
GetVectorizedSizeImpl
<
T
>
(
in1
));
int
out_vec_size
=
std
::
min
(
4
,
GetVectorizedSizeImpl
<
T
>
(
out
));
int
in_vec_size
=
std
::
min
(
platform
::
GetVectorizedSize
<
T
>
(
in0
),
platform
::
GetVectorizedSize
<
T
>
(
in1
));
int
out_vec_size
=
std
::
min
(
4
,
platform
::
GetVectorizedSize
<
T
>
(
out
));
int
vec_size
=
std
::
min
(
out_vec_size
,
in_vec_size
);
int
numel
=
m
*
n
;
...
...
paddle/fluid/platform/aligned_vector.h
0 → 100644
浏览文件 @
c171eca2
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.1 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.1
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
platform
{
// Aligned vector generates vectorized load/store on CUDA.
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
HOSTDEVICE
inline
const
T
&
operator
[](
int
i
)
const
{
return
val
[
i
];
}
HOSTDEVICE
inline
T
&
operator
[](
int
i
)
{
return
val
[
i
];
}
};
template
<
typename
T
,
int
Size
>
HOSTDEVICE
inline
void
Load
(
const
T
*
addr
,
AlignedVector
<
T
,
Size
>*
vec
)
{
const
AlignedVector
<
T
,
Size
>*
addr_vec
=
reinterpret_cast
<
const
AlignedVector
<
T
,
Size
>*>
(
addr
);
*
vec
=
*
addr_vec
;
}
template
<
typename
T
,
int
Size
>
HOSTDEVICE
inline
void
Store
(
const
AlignedVector
<
T
,
Size
>&
vec
,
T
*
addr
)
{
AlignedVector
<
T
,
Size
>*
addr_vec
=
reinterpret_cast
<
AlignedVector
<
T
,
Size
>*>
(
addr
);
*
addr_vec
=
vec
;
}
/*
* Only the address of input data is the multiplier of 1,2,4, vectorized load
* with corresponding multiplier-value is possible. Moreover, the maximum length
* of vectorized load is 128 bits once. Hence, valid length of vectorized load
* shall be determined under both former constraints.
*/
template
<
typename
T
>
int
GetVectorizedSize
(
const
T
*
pointer
)
{
constexpr
int
max_load_bits
=
128
;
int
valid_vec_size
=
max_load_bits
/
CHAR_BIT
/
sizeof
(
T
);
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec8
=
std
::
alignment_of
<
AlignedVector
<
T
,
8
>>::
value
;
// NOLINT
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
constexpr
int
vec2
=
std
::
alignment_of
<
AlignedVector
<
T
,
2
>>::
value
;
// NOLINT
if
(
address
%
vec8
==
0
)
{
/*
* Currently, decide to deal with no more than 4 data once while adopting
* vectorization load/store, if performance test shows that dealing with
* 8 data once in vectorization load/store does get optimized, return code
* below can be changed into " return std::min(8, valid_vec_size); " .
*/
return
std
::
min
(
4
,
valid_vec_size
);
}
else
if
(
address
%
vec4
==
0
)
{
return
std
::
min
(
4
,
valid_vec_size
);
}
else
if
(
address
%
vec2
==
0
)
{
return
std
::
min
(
2
,
valid_vec_size
);
}
else
{
return
1
;
}
}
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/fast_divmod.h
浏览文件 @
c171eca2
...
...
@@ -15,22 +15,17 @@ limitations under the License. */
#pragma once
#include <cstdint>
#include "paddle/fluid/platform/
hostdevice
.h"
#include "paddle/fluid/platform/
aligned_vector
.h"
#define INT_BITS 32
namespace
paddle
{
namespace
platform
{
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
CudaAlignedVector
{
T
val
[
Size
];
};
struct
FastDivMod
{
// 1st value represents the result of input number divides by recorded divisor
// 2nd value represents the result of input number modulo by recorded divisor
using
DivModT
=
Cuda
AlignedVector
<
uint32_t
,
2
>
;
using
DivModT
=
AlignedVector
<
uint32_t
,
2
>
;
FastDivMod
()
{}
HOSTDEVICE
FastDivMod
(
uint32_t
d
)
:
divisor
(
d
)
{
...
...
@@ -65,39 +60,5 @@ struct FastDivMod {
uint32_t
multiplier
;
};
/*
* Only the address of input data is the multiplier of 1,2,4, vectorized load
* with corresponding multiplier-value is possible. Moreover, the maximum length
* of vectorized load is 128 bits once. Hence, valid length of vectorized load
* shall be determined under both former constraints.
*/
template
<
typename
T
>
int
GetVectorizedSize
(
const
T
*
pointer
)
{
constexpr
int
max_load_bits
=
128
;
int
valid_vec_size
=
max_load_bits
/
CHAR_BIT
/
sizeof
(
T
);
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec8
=
std
::
alignment_of
<
CudaAlignedVector
<
T
,
8
>>::
value
;
// NOLINT
constexpr
int
vec4
=
std
::
alignment_of
<
CudaAlignedVector
<
T
,
4
>>::
value
;
// NOLINT
constexpr
int
vec2
=
std
::
alignment_of
<
CudaAlignedVector
<
T
,
2
>>::
value
;
// NOLINT
if
(
address
%
vec8
==
0
)
{
/*
* Currently, decide to deal with no more than 4 data once while adopting
* vectorization load/store, if performance test shows that dealing with
* 8 data once in vectorization load/store does get optimized, return code
* below can be changed into " return std::min(8, valid_vec_size); " .
*/
return
std
::
min
(
4
,
valid_vec_size
);
}
else
if
(
address
%
vec4
==
0
)
{
return
std
::
min
(
4
,
valid_vec_size
);
}
else
if
(
address
%
vec2
==
0
)
{
return
std
::
min
(
2
,
valid_vec_size
);
}
else
{
return
1
;
}
}
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录