Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
567e6bbc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
567e6bbc
编写于
12月 08, 2021
作者:
C
crystal
提交者:
GitHub
12月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implementation of broadcast sub backward by reduce (#37754)
* add boardcast_sub * add boardcast_sub
上级
b4a67491
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
113 addition
and
8 deletion
+113
-8
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+60
-1
paddle/fluid/operators/elementwise/elementwise_sub_op.h
paddle/fluid/operators/elementwise/elementwise_sub_op.h
+28
-7
paddle/fluid/operators/kernel_primitives/functor_primitives.h
...le/fluid/operators/kernel_primitives/functor_primitives.h
+14
-0
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+11
-0
未找到文件。
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
浏览文件 @
567e6bbc
...
...
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -30,12 +32,69 @@ static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
col
<
size
)
{
dx
[
col
]
=
dout
[
col
];
if
(
dx
!=
nullptr
)
{
dx
[
col
]
=
dout
[
col
];
}
dy
[
col
]
=
-
dout
[
col
];
col
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
default_elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
*
dout_data
=
dout
->
data
<
T
>
();
// dx
if
(
dx
!=
nullptr
)
{
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dx
->
dims
()
==
dout
->
dims
())
{
if
(
dx_data
!=
dout_data
)
{
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dx
);
}
}
else
{
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if
(
dx
->
IsSharedBufferWith
(
*
dout
))
{
dx
->
clear
();
dx
->
mutable_data
<
T
>
(
x
->
dims
(),
ctx
.
GetPlace
());
}
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
x
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
*
dout
,
dx
,
reduce_dims
,
stream
);
}
}
// dy
if
(
dy
!=
nullptr
)
{
auto
*
dy_data
=
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dy
->
dims
()
==
dout
->
dims
())
{
if
(
dy_data
!=
dout_data
)
{
dim3
block_size
=
dim3
(
ELEMENTWISE_BLOCK_SIZE
,
1
);
auto
size
=
dy
->
numel
();
dim3
grid_size
=
dim3
(
(
size
+
ELEMENTWISE_BLOCK_SIZE
-
1
)
/
ELEMENTWISE_BLOCK_SIZE
,
1
);
SimpleElemwiseSubGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
dout
->
data
<
T
>
(),
size
,
nullptr
,
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
else
{
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
y
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSub
>
(
*
dout
,
dy
,
reduce_dims
,
stream
);
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.h
浏览文件 @
567e6bbc
...
...
@@ -71,6 +71,21 @@ struct SubGradDY {
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
-
dout
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
default_elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
...
...
@@ -79,13 +94,21 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
default_elementwise_sub_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
default_elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
...
...
@@ -108,15 +131,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
// skip out
auto
*
out
=
dout
;
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_sub_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
else
{
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
default_elementwise_sub_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
}
};
...
...
paddle/fluid/operators/kernel_primitives/functor_primitives.h
浏览文件 @
567e6bbc
...
...
@@ -86,6 +86,20 @@ struct DivideFunctor {
Tx
n_inv
;
};
/**
* @brief Default inverse functor
*/
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
InverseFunctor
{
HOSTDEVICE
inline
InverseFunctor
()
{}
HOSTDEVICE
explicit
inline
InverseFunctor
(
int
n
)
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
-
x
);
}
};
/**
* @brief Default unary square functor
*/
...
...
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
浏览文件 @
567e6bbc
...
...
@@ -64,6 +64,17 @@ struct CustomSum {
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomSub
{
using
Transformer
=
kps
::
InverseFunctor
<
Tx
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMean
{
using
Transformer
=
kps
::
DivideFunctor
<
Tx
>
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录