Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
56c5e210
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
56c5e210
编写于
8月 22, 2021
作者:
Z
Zhang Zheng
提交者:
GitHub
8月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implementation of broadcast add backward by reduce (#34143)
上级
e2241a43
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
90 addition
and
264 deletion
+90
-264
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+52
-0
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+18
-264
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+20
-0
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
56c5e210
...
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -83,6 +85,56 @@ static __global__ void SimpleElemwiseAddGradCUDAKernel(
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
default_elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
*
dout_data
=
dout
->
data
<
T
>
();
// dx
if
(
dx
!=
nullptr
)
{
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dx
->
dims
()
==
dout
->
dims
())
{
if
(
dx_data
!=
dout_data
)
{
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dx
);
}
}
else
{
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if
(
dx
->
IsSharedBufferWith
(
*
dout
))
{
dx
->
clear
();
dx
->
mutable_data
<
T
>
(
x
->
dims
(),
ctx
.
GetPlace
());
}
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
x
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
*
dout
,
dx
,
reduce_dims
,
stream
);
}
}
// dy
if
(
dy
!=
nullptr
)
{
auto
*
dy_data
=
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dy
->
dims
()
==
dout
->
dims
())
{
if
(
dy_data
!=
dout_data
)
{
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dy
);
}
}
else
{
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
y
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
*
dout
,
dy
,
reduce_dims
,
stream
);
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
56c5e210
...
...
@@ -85,13 +85,14 @@ struct IdentityGrad {
};
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
default_elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
IdentityGrad
<
T
>
,
...
...
@@ -133,167 +134,6 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
default_elementwise_add_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
template
<
typename
T
,
int
BLOCK_W
,
int
BLOCK_H
>
__global__
void
MatrixColReduce
(
const
T
*
__restrict__
in
,
T
*
__restrict__
out
,
size_t
width
,
size_t
height
)
{
__shared__
T
sdata
[
BLOCK_H
][
BLOCK_W
+
1
];
size_t
idx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
size_t
width_stride
=
gridDim
.
x
*
blockDim
.
x
;
size_t
full_width
=
(
width
&
(
~
((
uint64_t
)(
BLOCK_W
-
1
))))
+
((
width
&
(
BLOCK_W
-
1
))
?
BLOCK_W
:
0
);
size_t
full_height
=
(
height
&
(
~
((
uint64_t
)(
BLOCK_H
-
1
))))
+
((
height
&
(
BLOCK_H
-
1
))
?
BLOCK_H
:
0
);
#pragma unroll
for
(
size_t
w
=
idx
;
w
<
full_width
;
w
+=
width_stride
)
{
sdata
[
threadIdx
.
y
][
threadIdx
.
x
]
=
0
;
__syncthreads
();
size_t
offset
=
w
+
threadIdx
.
y
*
width
;
#pragma unroll
for
(
size_t
h
=
threadIdx
.
y
;
h
<
full_height
;
h
+=
BLOCK_H
)
{
// block-stride loop across matrix height
sdata
[
threadIdx
.
y
][
threadIdx
.
x
]
+=
(
w
<
width
&&
h
<
height
)
?
in
[
offset
]
:
(
static_cast
<
T
>
(
0
));
offset
+=
width
*
BLOCK_H
;
}
__syncthreads
();
T
val
=
sdata
[
threadIdx
.
x
][
threadIdx
.
y
];
for
(
int
i
=
warpSize
>>
1
;
i
>
0
;
i
>>=
1
)
val
+=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
val
,
i
);
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
sdata
[
0
][
threadIdx
.
y
]
=
val
;
__syncthreads
();
if
((
threadIdx
.
y
==
0
)
&&
((
w
)
<
width
))
out
[
w
]
=
sdata
[
0
][
threadIdx
.
x
];
}
}
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template
<
int
SIZE
>
__global__
void
VecFP16MatrixColReduce
(
const
__half2
*
__restrict__
in
,
__half2
*
__restrict__
out
,
size_t
width
,
size_t
height
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
by
=
blockIdx
.
y
;
__half2
zero
=
__half2half2
(
static_cast
<
__half
>
(
0
));
const
int
cols
=
width
/
2
;
for
(;
idx
<
cols
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
__half2
sum
=
zero
;
for
(
int
row
=
0
;
row
<
SIZE
;
row
++
)
{
int
index
=
idx
+
(
row
+
by
*
SIZE
)
*
cols
;
sum
=
__hadd2
(
sum
,
in
[
index
]);
}
atomicAdd
(
&
(
out
[
idx
]),
sum
);
}
#endif
}
#endif
template
<
typename
T
>
__global__
void
MatrixReduceLongWidth
(
const
T
*
__restrict__
in
,
T
*
out
,
size_t
width
,
size_t
height
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(;
idx
<
width
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
row
=
0
;
row
<
height
;
row
++
)
{
sum
+=
in
[
idx
+
row
*
width
];
}
out
[
idx
]
=
sum
;
}
}
template
<
typename
T
,
int
VEC_SIZE
>
__global__
void
VecMatrixReduceLongWidth
(
const
T
*
__restrict__
in
,
T
*
out
,
size_t
width
,
size_t
height
)
{
using
LoadT
=
AlignedVector
<
T
,
VEC_SIZE
>
;
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
w
=
idx
*
VEC_SIZE
;
int
width_stride
=
blockDim
.
x
*
gridDim
.
x
*
VEC_SIZE
;
for
(;
w
<
width
;
w
+=
width_stride
)
{
T
zero
=
static_cast
<
T
>
(
0
);
T
sum
[
VEC_SIZE
]
=
{
zero
};
T
tmp_vec
[
VEC_SIZE
]
=
{
zero
};
LoadT
*
tmp_ptr
=
reinterpret_cast
<
LoadT
*>
(
&
tmp_vec
);
for
(
int
row
=
0
;
row
<
height
;
row
++
)
{
int
offset
=
width
*
row
+
w
;
*
tmp_ptr
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
in
[
offset
]);
for
(
int
v
=
0
;
v
<
VEC_SIZE
;
v
++
)
{
sum
[
v
]
+=
tmp_vec
[
v
];
}
}
for
(
int
v
=
0
;
v
<
VEC_SIZE
;
v
++
)
out
[
w
+
v
]
=
sum
[
v
];
}
}
#endif
#endif
bool
static
RunSpecialDims
(
const
framework
::
DDim
&
dx_dims
,
const
framework
::
DDim
&
dy_dims
,
const
framework
::
DDim
&
dout_dims
,
int
axis
)
{
auto
smaller_dims
=
dx_dims
;
auto
bigger_dims
=
dy_dims
;
auto
smaller_dims_size
=
smaller_dims
.
size
();
auto
bigger_dims_size
=
bigger_dims
.
size
();
int
smaller_ignore_size
=
0
;
int
bigger_ignore_size
=
0
;
for
(
int
i
=
0
;
i
<
smaller_dims_size
;
i
++
)
{
if
(
smaller_dims
[
i
]
==
1
)
smaller_ignore_size
++
;
else
break
;
}
for
(
int
i
=
0
;
i
<
bigger_dims_size
;
i
++
)
{
if
(
bigger_dims
[
i
]
==
1
)
bigger_ignore_size
++
;
else
break
;
}
int
smaller_real_size
=
smaller_dims
.
size
()
-
smaller_ignore_size
;
int
bigger_real_size
=
bigger_dims
.
size
()
-
bigger_ignore_size
;
if
(
smaller_real_size
==
bigger_real_size
)
return
false
;
if
(
bigger_real_size
<
smaller_real_size
)
{
smaller_dims
=
dy_dims
;
bigger_dims
=
dx_dims
;
std
::
swap
(
smaller_real_size
,
bigger_real_size
);
}
int
big_size
=
bigger_dims
.
size
();
int
small_size
=
smaller_dims
.
size
();
for
(
int
i
=
1
;
i
<=
smaller_real_size
;
i
++
)
{
if
(
bigger_dims
[
big_size
-
i
]
!=
smaller_dims
[
small_size
-
i
])
return
false
;
}
if
(
axis
!=
-
1
&&
(
axis
!=
(
bigger_real_size
-
smaller_real_size
)))
{
return
false
;
}
return
true
;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -304,6 +144,16 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
default_elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
#endif
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -322,102 +172,6 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
// skip out
auto
*
out
=
dout
;
// TODO(@wangchaochaohu, zhouwei35): Fix conv_transpose2d API(dataformat NHWC)
// error in Windows
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#ifdef __NVCC__
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
if
(
ctx
.
GetPlace
()
==
platform
::
CUDAPlace
()
&&
dx
!=
nullptr
&&
dy
!=
nullptr
&&
dout
!=
nullptr
&&
dx
->
numel
()
!=
dy
->
numel
()
&&
RunSpecialDims
(
dx
->
dims
(),
dy
->
dims
(),
dout
->
dims
(),
axis
))
{
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dy_data
=
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dout_data
=
dout
->
data
<
T
>
();
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
*
out_data
=
dx_data
;
int
width
=
dx
->
numel
();
int
height
=
dout
->
numel
()
/
width
;
if
(
dx
->
dims
()
==
dout
->
dims
())
{
width
=
dy
->
numel
();
height
=
dout
->
numel
()
/
width
;
out_data
=
dy_data
;
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dx
);
}
else
{
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dy
);
}
// special optimization using cub
if
(
width
==
1
)
{
int
nums
=
height
;
size_t
temp_storage_bytes
=
0
;
auto
err
=
cub
::
DeviceReduce
::
Sum
(
nullptr
,
temp_storage_bytes
,
dout_data
,
out_data
,
nums
,
stream
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
);
framework
::
Tensor
tmp
;
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
ctx
.
GetPlace
());
err
=
cub
::
DeviceReduce
::
Sum
(
temp_storage
,
temp_storage_bytes
,
dout_data
,
out_data
,
nums
,
stream
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
);
return
;
}
constexpr
int
block_x
=
32
;
constexpr
int
block_y
=
32
;
dim3
blocks
(
block_x
,
block_y
);
int
max_physical_threads
=
ctx
.
cuda_device_context
().
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_physical_threads
/
(
block_x
*
block_y
),
1
);
int
theory_block
=
(
width
+
blocks
.
x
-
1
)
/
blocks
.
x
;
dim3
grids
(
std
::
min
(
theory_block
,
max_blocks
));
#if CUDA_VERSION >= 10000
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
&&
width
<
2048
&&
width
%
2
==
0
&&
height
%
64
==
0
)
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
functor
;
if
(
dout
->
dims
()
==
dx
->
dims
())
functor
(
dev_ctx
,
dy
,
static_cast
<
T
>
(
0
));
else
functor
(
dev_ctx
,
dx
,
static_cast
<
T
>
(
0
));
const
__half2
*
ptr1
=
reinterpret_cast
<
const
__half2
*>
(
dout_data
);
__half2
*
ptr2
=
reinterpret_cast
<
__half2
*>
(
out_data
);
const
int
threads
=
128
;
dim3
grid
(
1
,
(
height
+
64
-
1
)
/
64
);
VecFP16MatrixColReduce
<
64
><<<
grid
,
threads
,
0
,
stream
>>>
(
ptr1
,
ptr2
,
width
,
height
);
return
;
}
#endif
if
(
width
/
height
<
32
)
{
MatrixColReduce
<
T
,
block_x
,
block_y
><<<
grids
,
blocks
,
0
,
stream
>>>
(
dout_data
,
out_data
,
width
,
height
);
}
else
{
size_t
thread_nums
=
1024
;
size_t
block_nums
=
(
width
+
thread_nums
-
1
)
/
thread_nums
;
int
vec_size
=
VectorizedSize
<
T
>
(
dout_data
);
if
(
vec_size
==
4
&&
width
%
4
==
0
)
{
block_nums
=
(
width
/
vec_size
+
thread_nums
-
1
)
/
thread_nums
;
VecMatrixReduceLongWidth
<
T
,
4
><<<
block_nums
,
thread_nums
,
0
,
stream
>>>
(
dout_data
,
out_data
,
width
,
height
);
}
else
{
MatrixReduceLongWidth
<
T
><<<
block_nums
,
thread_nums
,
0
,
stream
>>>
(
dout_data
,
out_data
,
width
,
height
);
}
}
return
;
}
#endif
#endif
// Special case when dy is not needed and dx doesn't reduce
if
(
dx
!=
nullptr
&&
dy
==
nullptr
&&
dx
->
dims
()
==
dout
->
dims
())
{
VLOG
(
4
)
<<
"Special case when dy is not needed and dx doesn't "
...
...
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
56c5e210
...
...
@@ -3038,5 +3038,25 @@ static inline void GetDoubleGradSafeTensor(
}
}
// for broadcast backwards
static
inline
std
::
vector
<
int
>
GetReduceDim
(
const
framework
::
DDim
&
in
,
const
framework
::
DDim
&
out
,
int
axis
)
{
axis
=
(
axis
==
-
1
?
std
::
abs
(
static_cast
<
int
>
(
out
.
size
()
-
in
.
size
()))
:
axis
);
std
::
vector
<
int
>
dims
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
dims
.
push_back
(
i
);
}
for
(
int
i
=
0
;
i
<
in
.
size
();
++
i
)
{
if
(
out
[
i
+
axis
]
!=
in
[
i
])
{
dims
.
push_back
(
i
+
axis
);
}
}
for
(
int
i
=
axis
+
in
.
size
();
i
<
out
.
size
();
++
i
)
{
dims
.
push_back
(
i
);
}
return
dims
;
}
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录