Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7a1e1193
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7a1e1193
编写于
1月 27, 2022
作者:
Y
YuanRisheng
提交者:
GitHub
1月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor elementwise sub grad (#39225)
上级
5631da9c
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
336 addition
and
198 deletion
+336
-198
paddle/fluid/operators/elementwise/elementwise_div_op.h
paddle/fluid/operators/elementwise/elementwise_div_op.h
+16
-1
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+0
-97
paddle/fluid/operators/elementwise/elementwise_sub_op.h
paddle/fluid/operators/elementwise/elementwise_sub_op.h
+22
-98
paddle/pten/core/kernel_alias_name.h
paddle/pten/core/kernel_alias_name.h
+1
-0
paddle/pten/kernels/cpu/elementwise.h
paddle/pten/kernels/cpu/elementwise.h
+34
-2
paddle/pten/kernels/cpu/elementwise_grad_kernel.cc
paddle/pten/kernels/cpu/elementwise_grad_kernel.cc
+54
-0
paddle/pten/kernels/elementwise_grad_kernel.h
paddle/pten/kernels/elementwise_grad_kernel.h
+18
-0
paddle/pten/kernels/gpu/elementwise.h
paddle/pten/kernels/gpu/elementwise.h
+108
-0
paddle/pten/kernels/gpu/elementwise_grad_kernel.cu
paddle/pten/kernels/gpu/elementwise_grad_kernel.cu
+60
-0
paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h
paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h
+23
-0
未找到文件。
paddle/fluid/operators/elementwise/elementwise_div_op.h
浏览文件 @
7a1e1193
...
@@ -16,11 +16,26 @@ limitations under the License. */
...
@@ -16,11 +16,26 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_sub
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
x_dims
=
x
->
dims
();
auto
y_dims
=
y
->
dims
();
if
(
x_dims
.
size
()
>=
y_dims
.
size
())
{
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
SubFunctor
<
T
>
(),
z
);
}
else
{
ElementwiseComputeEx
<
InverseSubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
InverseSubFunctor
<
T
>
(),
z
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_div
(
const
framework
::
ExecutionContext
&
ctx
,
void
default_elementwise_div
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
x
,
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
浏览文件 @
7a1e1193
...
@@ -17,103 +17,6 @@ limitations under the License. */
...
@@ -17,103 +17,6 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
static
__global__
void
SimpleElemwiseSubGradCUDAKernel
(
const
T
*
dout
,
int64_t
size
,
T
*
dx
,
T
*
dy
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
col
<
size
)
{
if
(
dx
!=
nullptr
)
{
dx
[
col
]
=
dout
[
col
];
}
dy
[
col
]
=
-
dout
[
col
];
col
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
default_elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
*
dout_data
=
dout
->
data
<
T
>
();
// dx
if
(
dx
!=
nullptr
)
{
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dx
->
dims
()
==
dout
->
dims
())
{
if
(
dx_data
!=
dout_data
)
{
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dx
);
}
}
else
{
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if
(
dx
->
IsSharedBufferWith
(
*
dout
))
{
dx
->
clear
();
dx
->
mutable_data
<
T
>
(
x
->
dims
(),
ctx
.
GetPlace
());
}
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
x
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
*
dout
,
dx
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
// dy
if
(
dy
!=
nullptr
)
{
auto
*
dy_data
=
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dy
->
dims
()
==
dout
->
dims
())
{
if
(
dy_data
!=
dout_data
)
{
dim3
block_size
=
dim3
(
PREDEFINED_BLOCK_SIZE
,
1
);
auto
size
=
dy
->
numel
();
dim3
grid_size
=
dim3
((
size
+
PREDEFINED_BLOCK_SIZE
-
1
)
/
PREDEFINED_BLOCK_SIZE
,
1
);
SimpleElemwiseSubGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
dout
->
data
<
T
>
(),
size
,
nullptr
,
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
else
{
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
y
->
dims
(),
out
->
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
InverseFunctor
<
T
>>
(
*
dout
,
dy
,
kps
::
InverseFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PREDEFINED_BLOCK_SIZE
,
1
);
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
((
size
+
PREDEFINED_BLOCK_SIZE
-
1
)
/
PREDEFINED_BLOCK_SIZE
,
1
);
SimpleElemwiseSubGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
dout
->
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
elementwise_sub
,
elementwise_sub
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseSubKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.h
浏览文件 @
7a1e1193
...
@@ -17,26 +17,11 @@ limitations under the License. */
...
@@ -17,26 +17,11 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/kernels/elementwise_grad_kernel.h"
#include "paddle/pten/kernels/math_kernel.h"
#include "paddle/pten/kernels/math_kernel.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_sub
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
x_dims
=
x
->
dims
();
auto
y_dims
=
y
->
dims
();
if
(
x_dims
.
size
()
>=
y_dims
.
size
())
{
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
SubFunctor
<
T
>
(),
z
);
}
else
{
ElementwiseComputeEx
<
InverseSubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
InverseSubFunctor
<
T
>
(),
z
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseSubKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ElementwiseSubKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -48,76 +33,13 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
...
@@ -48,76 +33,13 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
pt_x
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
x
);
auto
pt_y
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
y
);
auto
pt_z
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
z
);
pten
::
SubtractRawKernel
<
T
>
(
pten
::
SubtractRawKernel
<
T
>
(
static_cast
<
const
typename
framework
::
ConvertToPtenContext
<
static_cast
<
const
typename
framework
::
ConvertToPtenContext
<
DeviceContext
>::
TYPE
&>
(
dev_ctx
),
DeviceContext
>::
TYPE
&>
(
dev_ctx
),
*
pt_x
.
get
(),
*
pt_y
.
get
(),
axis
,
pt_z
.
get
()
);
*
x
,
*
y
,
axis
,
z
);
}
}
};
};
template
<
typename
T
>
struct
SubGradDX
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
;
}
};
template
<
typename
T
>
struct
SubGradDY
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
-
dout
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
default_elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>();
pten
::
ElemwiseExplicitGradCompute
<
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
dev_ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
default_elementwise_sub_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
default_elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
#endif
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseSubGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
class
ElementwiseSubGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
public:
public:
...
@@ -130,14 +52,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
...
@@ -130,14 +52,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
// skip out
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
*
out
=
dout
;
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_sub_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
pten
::
SubtractGradKernel
<
T
>
(
}
else
{
static_cast
<
const
typename
framework
::
ConvertToPtenContext
<
default_elementwise_sub_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
DeviceContext
>::
TYPE
&>
(
dev_ctx
),
dy
);
*
x
,
*
y
,
*
dout
,
axis
,
dx
,
dy
);
}
}
}
};
};
...
@@ -153,18 +74,21 @@ class ElementwiseSubDoubleGradKernel : public framework::OpKernel<T> {
...
@@ -153,18 +74,21 @@ class ElementwiseSubDoubleGradKernel : public framework::OpKernel<T> {
auto
*
ddy
=
ctx
.
Input
<
Tensor
>
(
"DDY"
);
auto
*
ddy
=
ctx
.
Input
<
Tensor
>
(
"DDY"
);
auto
*
ddout
=
ctx
.
Output
<
Tensor
>
(
"DDOut"
);
auto
*
ddout
=
ctx
.
Output
<
Tensor
>
(
"DDOut"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
// DDOut = ddx - ddy
paddle
::
optional
<
const
pten
::
DenseTensor
&>
ddx_optional
=
paddle
::
none
;
if
(
ddout
)
{
paddle
::
optional
<
const
pten
::
DenseTensor
&>
ddy_optional
=
paddle
::
none
;
Tensor
ddx_safe
,
ddy_safe
;
if
(
ddx
!=
nullptr
)
{
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
dout
,
ddx
,
&
ddx_safe
);
ddx_optional
=
*
ddx
;
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
y
,
ddy
,
&
ddy_safe
);
ddout
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
ddx_safe
,
&
ddy_safe
,
axis
,
SubFunctor
<
T
>
(),
ddout
);
}
}
if
(
ddy
!=
nullptr
)
{
ddy_optional
=
*
ddy
;
}
pten
::
SubtractDoubleGradKernel
<
T
>
(
static_cast
<
const
typename
framework
::
ConvertToPtenContext
<
DeviceContext
>::
TYPE
&>
(
dev_ctx
),
*
y
,
ddx_optional
,
ddy_optional
,
*
dout
,
axis
,
ddout
);
}
}
};
};
...
...
paddle/pten/core/kernel_alias_name.h
浏览文件 @
7a1e1193
...
@@ -25,6 +25,7 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
...
@@ -25,6 +25,7 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{
"elementwise_div"
,
"divide_raw"
},
{
"elementwise_div"
,
"divide_raw"
},
{
"elementwise_mul"
,
"muliply_raw"
},
{
"elementwise_mul"
,
"muliply_raw"
},
{
"elementwise_sub"
,
"subtract_raw"
},
{
"elementwise_sub"
,
"subtract_raw"
},
{
"elementwise_sub_grad"
,
"subtract_grad"
},
{
"fill_any_like"
,
"full_like"
},
{
"fill_any_like"
,
"full_like"
},
{
"fill_constant"
,
"full"
},
{
"fill_constant"
,
"full"
},
{
"flatten_contiguous_range"
,
"flatten"
},
{
"flatten_contiguous_range"
,
"flatten"
},
...
...
paddle/pten/kernels/cpu/elementwise.h
浏览文件 @
7a1e1193
...
@@ -743,8 +743,11 @@ void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx,
...
@@ -743,8 +743,11 @@ void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx,
}
}
}
}
// Add Grad
/*
******************************
Add Grad
******************************
*/
template
<
typename
T
>
template
<
typename
T
>
struct
IdentityGrad
{
struct
IdentityGrad
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
;
}
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
;
}
...
@@ -786,4 +789,33 @@ elementwise_add_grad(const CPUContext& ctx,
...
@@ -786,4 +789,33 @@ elementwise_add_grad(const CPUContext& ctx,
ctx
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
IdentityGrad
<
T
>
(),
IdentityGrad
<
T
>
());
ctx
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
IdentityGrad
<
T
>
(),
IdentityGrad
<
T
>
());
}
}
/*
******************************
Sub Grad
******************************
*/
template
<
typename
T
>
struct
SubGradDX
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
;
}
};
template
<
typename
T
>
struct
SubGradDY
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
-
dout
;
}
};
template
<
typename
T
>
void
elementwise_sub_grad
(
const
CPUContext
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dout
,
DenseTensor
*
dx
,
DenseTensor
*
dy
,
int
axis
=
-
1
)
{
ElemwiseExplicitGradCompute
<
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
ctx
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
}
}
// namespace pten
}
// namespace pten
paddle/pten/kernels/cpu/elementwise_grad_kernel.cc
浏览文件 @
7a1e1193
...
@@ -92,6 +92,38 @@ void AddTripleGradKernel(const Context& dev_ctx,
...
@@ -92,6 +92,38 @@ void AddTripleGradKernel(const Context& dev_ctx,
dev_ctx
,
ddx
,
ddy
,
d_ddout
,
axis
,
d_ddx
,
d_ddy
,
AddGradFunc
<
T
>
);
dev_ctx
,
ddx
,
ddy
,
d_ddout
,
axis
,
d_ddx
,
d_ddy
,
AddGradFunc
<
T
>
);
}
}
template
<
typename
T
,
typename
Context
>
void
SubtractGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
)
{
// skip out
auto
*
out
=
&
dout
;
elementwise_sub_grad
<
T
>
(
dev_ctx
,
x
,
y
,
*
out
,
dout
,
dx
,
dy
,
axis
);
}
template
<
typename
T
,
typename
Context
>
void
SubtractDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
y
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddy
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
ddout
)
{
pten
::
SubtractDoubleGradImpl
<
T
>
(
dev_ctx
,
y
,
ddx
,
ddy
,
dout
,
axis
,
ddout
,
ElementwiseCompute
<
funcs
::
SubtractFunctor
<
T
>
,
T
>
);
}
}
// namespace pten
}
// namespace pten
PT_REGISTER_KERNEL
(
add_grad
,
PT_REGISTER_KERNEL
(
add_grad
,
...
@@ -126,3 +158,25 @@ PT_REGISTER_KERNEL(add_triple_grad,
...
@@ -126,3 +158,25 @@ PT_REGISTER_KERNEL(add_triple_grad,
int64_t
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
subtract_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
SubtractGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
subtract_double_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
SubtractDoubleGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/elementwise_grad_kernel.h
浏览文件 @
7a1e1193
...
@@ -46,4 +46,22 @@ void AddTripleGradKernel(const Context& dev_ctx,
...
@@ -46,4 +46,22 @@ void AddTripleGradKernel(const Context& dev_ctx,
DenseTensor
*
d_ddx
,
DenseTensor
*
d_ddx
,
DenseTensor
*
d_ddy
);
DenseTensor
*
d_ddy
);
template
<
typename
T
,
typename
Context
>
void
SubtractGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
);
template
<
typename
T
,
typename
Context
>
void
SubtractDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
y
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddy
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
ddout
);
}
// namespace pten
}
// namespace pten
paddle/pten/kernels/gpu/elementwise.h
浏览文件 @
7a1e1193
...
@@ -1952,6 +1952,12 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
...
@@ -1952,6 +1952,12 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
}
}
}
}
/*
******************************
Add Grad
******************************
*/
template
<
typename
T
>
template
<
typename
T
>
static
__global__
void
SimpleElemwiseAddGradCUDAKernel
(
static
__global__
void
SimpleElemwiseAddGradCUDAKernel
(
const
T
*
__restrict__
dout
,
int
size
,
int
vec_size
,
T
*
dx
,
T
*
dy
)
{
const
T
*
__restrict__
dout
,
int
size
,
int
vec_size
,
T
*
dx
,
T
*
dy
)
{
...
@@ -2078,4 +2084,106 @@ void elementwise_add_grad(const GPUContext &ctx,
...
@@ -2078,4 +2084,106 @@ void elementwise_add_grad(const GPUContext &ctx,
}
}
}
}
/*
******************************
Sub Grad
******************************
*/
template
<
typename
T
>
static
__global__
void
SimpleElemwiseSubGradCUDAKernel
(
const
T
*
dout
,
int64_t
size
,
T
*
dx
,
T
*
dy
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
col
<
size
)
{
if
(
dx
!=
nullptr
)
{
dx
[
col
]
=
dout
[
col
];
}
dy
[
col
]
=
-
dout
[
col
];
col
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
template
<
typename
T
>
void
default_elementwise_sub_grad
(
const
GPUContext
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dout
,
DenseTensor
*
dx
,
DenseTensor
*
dy
,
int
axis
=
-
1
)
{
auto
*
dout_data
=
dout
.
data
<
T
>
();
// dx
if
(
dx
!=
nullptr
)
{
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dx
->
dims
()
==
dout
.
dims
())
{
if
(
dx_data
!=
dout_data
)
{
pten
::
Copy
(
ctx
,
dout
,
false
,
dx
);
}
}
else
{
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if
(
dx
->
IsSharedBufferWith
(
dout
))
{
dx
->
clear
();
dx
->
mutable_data
<
T
>
(
x
.
dims
(),
ctx
.
GetPlace
());
}
std
::
vector
<
int
>
reduce_dims
=
funcs
::
GetReduceDim
(
x
.
dims
(),
out
.
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
stream
();
kernels
::
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
dout
,
dx
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
// dy
if
(
dy
!=
nullptr
)
{
auto
*
dy_data
=
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dy
->
dims
()
==
dout
.
dims
())
{
if
(
dy_data
!=
dout_data
)
{
dim3
block_size
=
dim3
(
PREDEFINED_BLOCK_SIZE
,
1
);
auto
size
=
dy
->
numel
();
dim3
grid_size
=
dim3
((
size
+
PREDEFINED_BLOCK_SIZE
-
1
)
/
PREDEFINED_BLOCK_SIZE
,
1
);
SimpleElemwiseSubGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
dout
.
data
<
T
>
(),
size
,
nullptr
,
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
else
{
std
::
vector
<
int
>
reduce_dims
=
funcs
::
GetReduceDim
(
y
.
dims
(),
out
.
dims
(),
axis
);
gpuStream_t
stream
=
ctx
.
stream
();
kernels
::
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
InverseFunctor
<
T
>>
(
dout
,
dy
,
kps
::
InverseFunctor
<
T
>
(),
reduce_dims
,
stream
);
}
}
}
template
<
typename
T
>
void
elementwise_sub_grad
(
const
GPUContext
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dout
,
DenseTensor
*
dx
,
DenseTensor
*
dy
)
{
dim3
block_size
=
dim3
(
PREDEFINED_BLOCK_SIZE
,
1
);
auto
size
=
x
.
numel
();
dim3
grid_size
=
dim3
((
size
+
PREDEFINED_BLOCK_SIZE
-
1
)
/
PREDEFINED_BLOCK_SIZE
,
1
);
SimpleElemwiseSubGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
dout
.
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
// namespace pten
}
// namespace pten
paddle/pten/kernels/gpu/elementwise_grad_kernel.cu
浏览文件 @
7a1e1193
...
@@ -82,6 +82,42 @@ void AddTripleGradKernel(const Context& dev_ctx,
...
@@ -82,6 +82,42 @@ void AddTripleGradKernel(const Context& dev_ctx,
dev_ctx
,
ddx
,
ddy
,
d_ddout
,
axis
,
d_ddx
,
d_ddy
,
AddGradFunc
<
T
>
);
dev_ctx
,
ddx
,
ddy
,
d_ddout
,
axis
,
d_ddx
,
d_ddy
,
AddGradFunc
<
T
>
);
}
}
template
<
typename
T
,
typename
Context
>
void
SubtractGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
)
{
// skip out
auto
*
out
=
&
dout
;
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_sub_grad
<
T
>
(
dev_ctx
,
x
,
y
,
*
out
,
dout
,
dx
,
dy
);
}
else
{
default_elementwise_sub_grad
<
T
>
(
dev_ctx
,
x
,
y
,
*
out
,
dout
,
dx
,
dy
,
axis
);
}
}
template
<
typename
T
,
typename
Context
>
void
SubtractDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
y
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddy
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
ddout
)
{
pten
::
SubtractDoubleGradImpl
<
T
>
(
dev_ctx
,
y
,
ddx
,
ddy
,
dout
,
axis
,
ddout
,
ElementwiseCompute
<
funcs
::
SubtractFunctor
<
T
>
,
T
>
);
}
}
// namespace pten
}
// namespace pten
PT_REGISTER_KERNEL
(
add_grad
,
PT_REGISTER_KERNEL
(
add_grad
,
...
@@ -119,3 +155,27 @@ PT_REGISTER_KERNEL(add_triple_grad,
...
@@ -119,3 +155,27 @@ PT_REGISTER_KERNEL(add_triple_grad,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
subtract_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
SubtractGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
subtract_double_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
SubtractDoubleGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h
浏览文件 @
7a1e1193
...
@@ -85,4 +85,27 @@ void AddDoubleGradImpl(const Context& dev_ctx,
...
@@ -85,4 +85,27 @@ void AddDoubleGradImpl(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
,
typename
GradFunc
>
void
SubtractDoubleGradImpl
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
y
,
const
paddle
::
optional
<
const
DenseTensor
&>&
ddx
,
const
paddle
::
optional
<
const
DenseTensor
&>&
ddy
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
ddout
,
GradFunc
grad_func
)
{
// DDOut = ddx - ddy
if
(
ddout
)
{
DenseTensor
ddx_safe
,
ddy_safe
;
funcs
::
GetDoubleGradSafeTensor
<
Context
,
T
>
(
dev_ctx
,
dout
,
ddx
.
get_ptr
(),
&
ddx_safe
);
funcs
::
GetDoubleGradSafeTensor
<
Context
,
T
>
(
dev_ctx
,
y
,
ddy
.
get_ptr
(),
&
ddy_safe
);
ddout
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
grad_func
(
dev_ctx
,
ddx_safe
,
ddy_safe
,
axis
,
funcs
::
SubtractFunctor
<
T
>
(),
ddout
);
}
}
}
// namespace pten
}
// namespace pten
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录