Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
452c75b8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
452c75b8
编写于
3月 09, 2022
作者:
Y
YuanRisheng
提交者:
GitHub
3月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move elementwise mul grad (#40252)
上级
0604df9e
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
539 addition
and
401 deletion
+539
-401
paddle/fluid/framework/new_executor/standalone_executor_test.cc
.../fluid/framework/new_executor/standalone_executor_test.cc
+1
-1
paddle/fluid/operators/elementwise/elementwise_functor.h
paddle/fluid/operators/elementwise/elementwise_functor.h
+0
-41
paddle/fluid/operators/elementwise/elementwise_mul_op.cc
paddle/fluid/operators/elementwise/elementwise_mul_op.cc
+0
-49
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+0
-68
paddle/fluid/operators/elementwise/elementwise_mul_op.h
paddle/fluid/operators/elementwise/elementwise_mul_op.h
+0
-238
paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
+57
-4
paddle/phi/kernels/elementwise_grad_kernel.h
paddle/phi/kernels/elementwise_grad_kernel.h
+39
-0
paddle/phi/kernels/funcs/elementwise_functor.h
paddle/phi/kernels/funcs/elementwise_functor.h
+44
-0
paddle/phi/kernels/gpu/elementwise_grad.h
paddle/phi/kernels/gpu/elementwise_grad.h
+37
-0
paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
+54
-0
paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
+273
-0
paddle/phi/ops/compat/elementwise_sig.cc
paddle/phi/ops/compat/elementwise_sig.cc
+34
-0
未找到文件。
paddle/fluid/framework/new_executor/standalone_executor_test.cc
浏览文件 @
452c75b8
...
...
@@ -46,7 +46,7 @@ USE_OP(matmul_grad);
USE_OP
(
square
);
USE_OP
(
transpose2_grad
);
USE_OP
(
concat_grad
);
USE_OP
(
elementwise_mul_grad
);
USE_OP
_ITSELF
(
elementwise_mul_grad
);
USE_OP
(
sigmoid_grad
);
USE_OP
(
tanh_grad
);
USE_OP
(
sum
);
...
...
paddle/fluid/operators/elementwise/elementwise_functor.h
浏览文件 @
452c75b8
...
...
@@ -196,47 +196,6 @@ struct MinGradXYFunctor {
}
};
template
<
typename
T
>
struct
MulGradFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
*
b
;
}
};
template
<
typename
T
>
struct
MulGradFunctor
<
Complex
<
T
>>
{
inline
HOSTDEVICE
Complex
<
T
>
operator
()(
const
Complex
<
T
>
a
,
const
Complex
<
T
>
b
)
const
{
Complex
<
T
>
b_conj
(
b
.
real
,
-
b
.
imag
);
return
a
*
b_conj
;
}
};
template
<
typename
InT
,
typename
OutT
>
struct
MulGradXYFunctor
{
inline
HOSTDEVICE
phi
::
Array
<
OutT
,
2
>
operator
()(
const
InT
a
,
const
InT
b
,
const
InT
c
)
{
phi
::
Array
<
OutT
,
2
>
outs
;
// dx = dout * y
outs
[
0
]
=
a
*
b
;
// dy = dout * x
outs
[
1
]
=
a
*
c
;
return
outs
;
}
};
template
<
typename
InT
,
typename
OutT
>
struct
MulGradXYFunctor
<
Complex
<
InT
>
,
Complex
<
OutT
>>
{
inline
HOSTDEVICE
phi
::
Array
<
Complex
<
OutT
>
,
2
>
operator
()(
const
Complex
<
InT
>
a
,
const
Complex
<
InT
>
b
,
const
Complex
<
InT
>
c
)
{
phi
::
Array
<
Complex
<
OutT
>
,
2
>
outs
;
// dx = dout * y
Complex
<
InT
>
b_conj
(
b
.
real
,
-
b
.
imag
);
outs
[
0
]
=
a
*
b_conj
;
// dy = dout * x
Complex
<
InT
>
c_conj
(
c
.
real
,
-
c
.
imag
);
outs
[
1
]
=
a
*
c_conj
;
return
outs
;
}
};
// Ternary compare
template
<
typename
T
>
struct
MaxGradXFunctor
{
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.cc
浏览文件 @
452c75b8
...
...
@@ -173,55 +173,6 @@ REGISTER_OP_CPU_KERNEL(
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseMulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CPU_KERNEL
(
elementwise_mul_grad
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
>
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CPU_KERNEL
(
elementwise_mul_grad_grad
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseMulDoubleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CPU_KERNEL
(
elementwise_mul_triple_grad
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
>
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
ElementwiseMulTripleGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_VERSION
(
elementwise_mul
)
.
AddCheckpoint
(
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
浏览文件 @
452c75b8
...
...
@@ -63,33 +63,6 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
ElementwiseMulGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
const
auto
place
=
ctx
.
GetPlace
();
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
)
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
dout
,
y
,
x
};
GetGradXAndYOut
<
ElementwiseType
::
kTernary
,
T
>
(
dev_ctx
,
place
,
axis
,
ins
,
dout
,
dx
,
dy
,
MulGradXYFunctor
<
T
,
T
>
());
}
else
if
(
dx
!=
nullptr
&&
dy
==
nullptr
)
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
dout
,
y
};
GetGradXOrYOut
<
ElementwiseType
::
kBinary
,
T
>
(
dev_ctx
,
place
,
axis
,
ins
,
dout
,
dx
,
MulGradFunctor
<
T
>
());
}
else
if
(
dx
==
nullptr
&&
dy
!=
nullptr
)
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
dout
,
x
};
GetGradXOrYOut
<
ElementwiseType
::
kBinary
,
T
>
(
dev_ctx
,
place
,
axis
,
ins
,
dout
,
dy
,
MulGradFunctor
<
T
>
());
}
}
}
// namespace operators
}
// namespace paddle
...
...
@@ -103,44 +76,3 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
ElementwiseMulKernel
<
plat
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
ElementwiseMulKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
float
>>
,
ops
::
ElementwiseMulKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
double
>>
);
REGISTER_OP_CUDA_KERNEL
(
elementwise_mul_grad
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
bool
>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
float
>>
,
ops
::
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
double
>>
);
REGISTER_OP_CUDA_KERNEL
(
elementwise_mul_grad_grad
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
bool
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
float
>>
,
ops
::
ElementwiseMulDoubleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
double
>>
);
REGISTER_OP_CUDA_KERNEL
(
elementwise_mul_triple_grad
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
int64_t
>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
bool
>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
float
>>
,
ops
::
ElementwiseMulTripleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
complex
<
double
>>
);
paddle/fluid/operators/elementwise/elementwise_mul_op.h
浏览文件 @
452c75b8
...
...
@@ -137,244 +137,6 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}
}
};
template
<
typename
T
>
struct
MulGradDX
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
y
;
}
};
template
<
typename
T
>
struct
MulGradDX
<
paddle
::
platform
::
complex
<
T
>>
{
HOSTDEVICE
paddle
::
platform
::
complex
<
T
>
operator
()(
paddle
::
platform
::
complex
<
T
>
x
,
paddle
::
platform
::
complex
<
T
>
y
,
paddle
::
platform
::
complex
<
T
>
out
,
paddle
::
platform
::
complex
<
T
>
dout
)
const
{
paddle
::
platform
::
complex
<
T
>
y_conj
(
y
.
real
,
-
y
.
imag
);
return
dout
*
y_conj
;
}
};
template
<
typename
T
>
struct
MulGradDY
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
x
;
}
};
template
<
typename
T
>
struct
MulGradDY
<
paddle
::
platform
::
complex
<
T
>>
{
HOSTDEVICE
paddle
::
platform
::
complex
<
T
>
operator
()(
paddle
::
platform
::
complex
<
T
>
x
,
paddle
::
platform
::
complex
<
T
>
y
,
paddle
::
platform
::
complex
<
T
>
out
,
paddle
::
platform
::
complex
<
T
>
dout
)
const
{
paddle
::
platform
::
complex
<
T
>
x_conj
(
x
.
real
,
-
x
.
imag
);
return
dout
*
x_conj
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
ElementwiseMulGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseGradCompute
<
DeviceContext
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
ElementwiseMulGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
#endif
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseMulGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
ElemwiseGradKernel
<
T
>::
Compute
(
ctx
);
using
Tensor
=
framework
::
Tensor
;
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
out
=
dout
;
// out is not necessary
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
ElementwiseMulGrad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseMulDoubleGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
Tensor
=
framework
::
Tensor
;
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
"DOut"
);
auto
*
ddx
=
ctx
.
Input
<
Tensor
>
(
"DDX"
);
auto
*
ddy
=
ctx
.
Input
<
Tensor
>
(
"DDY"
);
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
ddout
=
ctx
.
Output
<
Tensor
>
(
"DDOut"
);
if
(
ddout
)
ddout
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
Tensor
ddx_safe
,
ddy_safe
;
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
x
,
ddx
,
&
ddx_safe
);
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
y
,
ddy
,
&
ddy_safe
);
// dx = dout * ddy
// dy = dout * ddx
// ddout = ddx * y + x * ddy
// change computation sequence to save memory, so ddout can inplace ddx and
// dx can be used as 'tmp' tensor
// (1) dx = x * ddy
// (2) dy = dout * ddx
// (3) ddout = ddx * y
// (4) ddout = ddout + dx
// (5) dx = dout * ddy
if
(
ddout
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
// size(ddout) > size(ddx), ddout can't use memory of ddx using inplace
if
(
ddout
->
numel
()
>
ddx
->
numel
())
{
ElemwiseGradCompute
<
DeviceContext
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
ctx
,
ddx_safe
,
ddy_safe
,
*
dout
,
*
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
Tensor
ddout_tmp
;
ddout_tmp
.
mutable_data
<
T
>
(
ddout
->
dims
(),
ctx
.
GetPlace
());
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
y
,
&
ddx_safe
,
ddout
);
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
&
ddy_safe
,
x
,
&
ddout_tmp
);
auto
ddout_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ddout
);
auto
ddout_tmp_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
ddout_tmp
);
ddout_t
.
device
(
place
)
=
ddout_t
+
ddout_tmp_t
;
}
else
{
// use dx to save memory, other than alloc tmp tensor
Tensor
*
ddout_tmp
=
dx
;
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
x
,
&
ddy_safe
,
ddout_tmp
);
// NOTE: in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first
// output tensor will not be activated, DivGradDx function will not
// be called and can be ignored, the first branch has little effect
// on running speed.
ElemwiseGradCompute
<
DeviceContext
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
ctx
,
ddx_safe
,
ddy_safe
,
*
dout
,
*
dout
,
axis
,
nullptr
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
&
ddx_safe
,
y
,
ddout
);
auto
ddout_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ddout
);
auto
ddout_tmp_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ddout_tmp
);
ddout_t
.
device
(
place
)
=
ddout_t
+
ddout_tmp_t
;
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
dout
,
&
ddy_safe
,
dx
);
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseMulTripleGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
Tensor
=
framework
::
Tensor
;
// get input
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
framework
::
Tensor
>
(
"DOut"
);
auto
*
ddx
=
ctx
.
Input
<
framework
::
Tensor
>
(
"DDX"
);
auto
*
ddy
=
ctx
.
Input
<
framework
::
Tensor
>
(
"DDY"
);
auto
*
d_dx
=
ctx
.
Input
<
framework
::
Tensor
>
(
"D_DX"
);
auto
*
d_dy
=
ctx
.
Input
<
framework
::
Tensor
>
(
"D_DY"
);
auto
*
d_ddout
=
ctx
.
Input
<
framework
::
Tensor
>
(
"D_DDOut"
);
// get output
auto
*
out_d_x
=
ctx
.
Output
<
framework
::
Tensor
>
(
"D_X"
);
auto
*
out_d_y
=
ctx
.
Output
<
framework
::
Tensor
>
(
"D_Y"
);
auto
*
out_d_dout
=
ctx
.
Output
<
framework
::
Tensor
>
(
"D_DOut"
);
auto
*
out_d_ddx
=
ctx
.
Output
<
framework
::
Tensor
>
(
"D_DDX"
);
auto
*
out_d_ddy
=
ctx
.
Output
<
framework
::
Tensor
>
(
"D_DDY"
);
if
(
out_d_x
)
out_d_x
->
mutable_data
<
T
>
(
x
->
dims
(),
ctx
.
GetPlace
());
if
(
out_d_y
)
out_d_y
->
mutable_data
<
T
>
(
y
->
dims
(),
ctx
.
GetPlace
());
if
(
out_d_dout
)
out_d_dout
->
mutable_data
<
T
>
(
dout
->
dims
(),
ctx
.
GetPlace
());
if
(
out_d_ddx
)
out_d_ddx
->
mutable_data
<
T
>
(
x
->
dims
(),
ctx
.
GetPlace
());
if
(
out_d_ddy
)
out_d_ddy
->
mutable_data
<
T
>
(
y
->
dims
(),
ctx
.
GetPlace
());
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Tensor
ddx_safe
,
ddy_safe
;
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
x
,
ddx
,
&
ddx_safe
);
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
y
,
ddy
,
&
ddy_safe
);
if
(
d_ddout
)
{
if
(
out_d_x
)
{
// out_d_x = ddy * d_ddout
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
&
ddy_safe
,
d_ddout
,
out_d_x
);
}
if
(
out_d_y
)
{
// out_d_y = ddx * d_ddout
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
&
ddx_safe
,
d_ddout
,
out_d_y
);
}
}
if
(
out_d_dout
)
{
// get out_d_dout
// out_d_dout = ddy * d_dx + d_dy * ddx
Tensor
out_d_dout_tmp
;
out_d_dout_tmp
.
mutable_data
<
T
>
(
dout
->
dims
(),
ctx
.
GetPlace
());
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
d_dy
,
&
ddx_safe
,
out_d_dout
);
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
&
ddy_safe
,
d_dx
,
&
out_d_dout_tmp
);
auto
out_d_dout_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out_d_dout
);
auto
out_d_dout_tmp_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
out_d_dout_tmp
);
out_d_dout_t
.
device
(
place
)
=
out_d_dout_t
+
out_d_dout_tmp_t
;
}
if
(
out_d_ddx
)
{
// get out_d_ddx
// out_d_ddx = dout * d_dy + y * d_ddout
Tensor
out_d_ddx_tmp
;
out_d_ddx_tmp
.
mutable_data
<
T
>
(
ddx
->
dims
(),
ctx
.
GetPlace
());
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
dout
,
d_dy
,
out_d_ddx
);
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
y
,
d_ddout
,
&
out_d_ddx_tmp
);
auto
out_d_ddx_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out_d_ddx
);
auto
out_d_ddx_tmp_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
out_d_ddx_tmp
);
out_d_ddx_t
.
device
(
place
)
=
out_d_ddx_t
+
out_d_ddx_tmp_t
;
}
if
(
out_d_ddy
)
{
// get out_d_ddy
// out_d_ddy = dout * d_dx + x * d_ddout
Tensor
out_d_ddy_tmp
;
out_d_ddy_tmp
.
mutable_data
<
T
>
(
ddy
->
dims
(),
ctx
.
GetPlace
());
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
dout
,
d_dx
,
out_d_ddy
);
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
x
,
d_ddout
,
&
out_d_ddy_tmp
);
auto
out_d_ddy_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out_d_ddy
);
auto
out_d_ddy_tmp_t
=
framework
::
EigenVector
<
T
>::
Flatten
(
out_d_ddy_tmp
);
out_d_ddy_t
.
device
(
place
)
=
out_d_ddy_t
+
out_d_ddy_tmp_t
;
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
浏览文件 @
452c75b8
...
...
@@ -121,6 +121,20 @@ void DivideGradKernel(const Context& dev_ctx,
dev_ctx
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
DivGradDX
<
T
>
(),
DivGradDY
<
T
>
());
}
template
<
typename
T
,
typename
Context
>
void
MultiplyGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
)
{
funcs
::
ElementwiseGradPreProcess
(
dout
,
dx
);
auto
*
out
=
&
dout
;
// out is not necessary
phi
::
funcs
::
ElemwiseGradCompute
<
Context
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
dev_ctx
,
x
,
y
,
*
out
,
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
}
}
// namespace phi
PD_REGISTER_KERNEL
(
add_grad
,
...
...
@@ -193,8 +207,8 @@ PD_REGISTER_KERNEL(divide_grad,
double
,
int
,
int64_t
,
p
addle
::
platform
::
complex
<
float
>
,
p
addle
::
platform
::
complex
<
double
>
)
{}
p
hi
::
dtype
::
complex
<
float
>
,
p
hi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
divide_double_grad
,
CPU
,
...
...
@@ -204,5 +218,44 @@ PD_REGISTER_KERNEL(divide_double_grad,
double
,
int
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
multiply_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
MultiplyGradKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
multiply_double_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
MultiplyDoubleGradKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
multiply_triple_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
MultiplyTripleGradKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
paddle/phi/kernels/elementwise_grad_kernel.h
浏览文件 @
452c75b8
...
...
@@ -85,4 +85,43 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
DenseTensor
*
dy
,
DenseTensor
*
dout
,
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
void
MultiplyGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
);
template
<
typename
T
,
typename
Context
>
void
MultiplyDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddy
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
,
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
void
MultiplyTripleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddy
,
const
DenseTensor
&
d_dx
,
const
DenseTensor
&
d_dy
,
paddle
::
optional
<
const
DenseTensor
&>
d_ddout
,
int
axis
,
DenseTensor
*
d_x
,
DenseTensor
*
d_y
,
DenseTensor
*
d_dout
,
DenseTensor
*
d_ddx
,
DenseTensor
*
d_ddy
);
}
// namespace phi
paddle/phi/kernels/funcs/elementwise_functor.h
浏览文件 @
452c75b8
...
...
@@ -160,5 +160,49 @@ struct DivGradYFunctor<ComplexType<T>> {
}
};
template
<
typename
T
>
struct
MultiplyGradFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
*
b
;
}
};
template
<
typename
T
>
struct
MultiplyGradFunctor
<
ComplexType
<
T
>>
{
inline
HOSTDEVICE
ComplexType
<
T
>
operator
()(
const
ComplexType
<
T
>
a
,
const
ComplexType
<
T
>
b
)
const
{
ComplexType
<
T
>
b_conj
(
b
.
real
,
-
b
.
imag
);
return
a
*
b_conj
;
}
};
template
<
typename
InT
,
typename
OutT
>
struct
MultiplyGradXYFunctor
{
inline
HOSTDEVICE
phi
::
Array
<
OutT
,
2
>
operator
()(
const
InT
a
,
const
InT
b
,
const
InT
c
)
{
phi
::
Array
<
OutT
,
2
>
outs
;
// dx = dout * y
outs
[
0
]
=
a
*
b
;
// dy = dout * x
outs
[
1
]
=
a
*
c
;
return
outs
;
}
};
template
<
typename
InT
,
typename
OutT
>
struct
MultiplyGradXYFunctor
<
ComplexType
<
InT
>
,
ComplexType
<
OutT
>>
{
inline
HOSTDEVICE
phi
::
Array
<
ComplexType
<
OutT
>
,
2
>
operator
()(
const
ComplexType
<
InT
>
a
,
const
ComplexType
<
InT
>
b
,
const
ComplexType
<
InT
>
c
)
{
phi
::
Array
<
ComplexType
<
OutT
>
,
2
>
outs
;
// dx = dout * y
ComplexType
<
InT
>
b_conj
(
b
.
real
,
-
b
.
imag
);
outs
[
0
]
=
a
*
b_conj
;
// dy = dout * x
ComplexType
<
InT
>
c_conj
(
c
.
real
,
-
c
.
imag
);
outs
[
1
]
=
a
*
c_conj
;
return
outs
;
}
};
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/gpu/elementwise_grad.h
浏览文件 @
452c75b8
...
...
@@ -360,4 +360,41 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx,
}
}
/*
******************************
Mul Grad
******************************
*/
template
<
typename
T
>
void
ElementwiseMulGrad
(
const
GPUContext
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
DenseTensor
*
dx
,
DenseTensor
*
dy
,
int
axis
)
{
const
auto
place
=
dev_ctx
.
GetPlace
();
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
)
{
std
::
vector
<
const
DenseTensor
*>
ins
=
{
&
dout
,
&
y
,
&
x
};
GetGradXAndYOut
<
ElementwiseType
::
kTernary
,
T
>
(
dev_ctx
,
place
,
axis
,
ins
,
dout
,
dx
,
dy
,
funcs
::
MultiplyGradXYFunctor
<
T
,
T
>
());
}
else
if
(
dx
!=
nullptr
&&
dy
==
nullptr
)
{
std
::
vector
<
const
DenseTensor
*>
ins
=
{
&
dout
,
&
y
};
GetGradXOrYOut
<
ElementwiseType
::
kBinary
,
T
>
(
dev_ctx
,
place
,
axis
,
ins
,
dout
,
dx
,
funcs
::
MultiplyGradFunctor
<
T
>
());
}
else
if
(
dx
==
nullptr
&&
dy
!=
nullptr
)
{
std
::
vector
<
const
DenseTensor
*>
ins
=
{
&
dout
,
&
x
};
GetGradXOrYOut
<
ElementwiseType
::
kBinary
,
T
>
(
dev_ctx
,
place
,
axis
,
ins
,
dout
,
dy
,
funcs
::
MultiplyGradFunctor
<
T
>
());
}
}
}
// namespace phi
paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
浏览文件 @
452c75b8
...
...
@@ -136,6 +136,18 @@ void DivideGradKernel(const Context& dev_ctx,
}
}
template
<
typename
T
,
typename
Context
>
void
MultiplyGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
)
{
funcs
::
ElementwiseGradPreProcess
(
dout
,
dx
);
ElementwiseMulGrad
<
T
>
(
dev_ctx
,
x
,
y
,
dout
,
dx
,
dy
,
axis
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
add_grad
,
...
...
@@ -228,3 +240,45 @@ PD_REGISTER_KERNEL(divide_double_grad,
int64_t
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
multiply_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
MultiplyGradKernel
,
float
,
phi
::
dtype
::
float16
,
double
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
multiply_double_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
MultiplyDoubleGradKernel
,
float
,
phi
::
dtype
::
float16
,
double
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
PD_REGISTER_KERNEL
(
multiply_triple_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
MultiplyTripleGradKernel
,
float
,
phi
::
dtype
::
float16
,
double
,
int
,
int64_t
,
bool
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
浏览文件 @
452c75b8
...
...
@@ -259,4 +259,277 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
}
}
template
<
typename
T
>
struct
MulGradDX
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
y
;
}
};
template
<
typename
T
>
struct
MulGradDX
<
phi
::
dtype
::
complex
<
T
>>
{
HOSTDEVICE
phi
::
dtype
::
complex
<
T
>
operator
()(
phi
::
dtype
::
complex
<
T
>
x
,
phi
::
dtype
::
complex
<
T
>
y
,
phi
::
dtype
::
complex
<
T
>
out
,
phi
::
dtype
::
complex
<
T
>
dout
)
const
{
phi
::
dtype
::
complex
<
T
>
y_conj
(
y
.
real
,
-
y
.
imag
);
return
dout
*
y_conj
;
}
};
/*
******************************
Multiply Grad
******************************
*/
template
<
typename
T
>
struct
MulGradDY
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
x
;
}
};
template
<
typename
T
>
struct
MulGradDY
<
phi
::
dtype
::
complex
<
T
>>
{
HOSTDEVICE
phi
::
dtype
::
complex
<
T
>
operator
()(
phi
::
dtype
::
complex
<
T
>
x
,
phi
::
dtype
::
complex
<
T
>
y
,
phi
::
dtype
::
complex
<
T
>
out
,
phi
::
dtype
::
complex
<
T
>
dout
)
const
{
phi
::
dtype
::
complex
<
T
>
x_conj
(
x
.
real
,
-
x
.
imag
);
return
dout
*
x_conj
;
}
};
template
<
typename
T
,
typename
Context
>
void
MultiplyDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddy
,
int
axis
,
DenseTensor
*
dx
,
DenseTensor
*
dy
,
DenseTensor
*
ddout
)
{
if
(
ddout
)
dev_ctx
.
template
Alloc
<
T
>(
ddout
);
DenseTensor
ddx_safe
,
ddy_safe
;
funcs
::
GetDoubleGradSafeTensor
<
Context
,
T
>
(
dev_ctx
,
x
,
ddx
.
get_ptr
(),
&
ddx_safe
);
funcs
::
GetDoubleGradSafeTensor
<
Context
,
T
>
(
dev_ctx
,
y
,
ddy
.
get_ptr
(),
&
ddy_safe
);
// dx = dout * ddy
// dy = dout * ddx
// ddout = ddx * y + x * ddy
// change computation sequence to save memory, so ddout can inplace ddx and
// dx can be used as 'tmp' tensor
// (1) dx = x * ddy
// (2) dy = dout * ddx
// (3) ddout = ddx * y
// (4) ddout = ddout + dx
// (5) dx = dout * ddy
if
(
ddout
)
{
auto
&
place
=
*
dev_ctx
.
eigen_device
();
// size(ddout) > size(ddx), ddout can't use memory of ddx using inplace
if
(
ddout
->
numel
()
>
ddx
.
get_ptr
()
->
numel
())
{
phi
::
funcs
::
ElemwiseGradCompute
<
Context
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
dev_ctx
,
ddx_safe
,
ddy_safe
,
dout
,
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
DenseTensor
ddout_tmp
;
ddout_tmp
.
Resize
(
ddout
->
dims
());
dev_ctx
.
template
Alloc
<
T
>(
&
ddout_tmp
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
y
,
ddx_safe
,
ddout
,
axis
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
ddy_safe
,
x
,
&
ddout_tmp
,
axis
);
auto
ddout_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
ddout
);
auto
ddout_tmp_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
ddout_tmp
);
ddout_t
.
device
(
place
)
=
ddout_t
+
ddout_tmp_t
;
}
else
{
// use dx to save memory, other than alloc tmp tensor
DenseTensor
*
ddout_tmp
=
dx
;
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
x
,
ddy_safe
,
ddout_tmp
,
axis
);
// NOTE: in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first
// output tensor will not be activated, DivGradDx function will not
// be called and can be ignored, the first branch has little effect
// on running speed.
phi
::
funcs
::
ElemwiseGradCompute
<
Context
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
dev_ctx
,
ddx_safe
,
ddy_safe
,
dout
,
dout
,
axis
,
nullptr
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
ddx_safe
,
y
,
ddout
,
axis
);
auto
ddout_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
ddout
);
auto
ddout_tmp_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
ddout_tmp
);
ddout_t
.
device
(
place
)
=
ddout_t
+
ddout_tmp_t
;
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
dout
,
ddy_safe
,
dx
,
axis
);
}
}
}
template
<
typename
T
,
typename
Context
>
void
MultiplyTripleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
dout
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddy
,
const
DenseTensor
&
d_dx
,
const
DenseTensor
&
d_dy
,
paddle
::
optional
<
const
DenseTensor
&>
d_ddout
,
int
axis
,
DenseTensor
*
d_x
,
DenseTensor
*
d_y
,
DenseTensor
*
d_dout
,
DenseTensor
*
d_ddx
,
DenseTensor
*
d_ddy
)
{
if
(
d_x
)
{
d_x
->
Resize
(
x
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
d_x
);
}
if
(
d_y
)
{
d_y
->
Resize
(
y
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
d_y
);
}
if
(
d_dout
)
{
d_dout
->
Resize
(
dout
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
d_dout
);
}
if
(
d_ddx
)
{
d_ddx
->
Resize
(
x
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
d_ddx
);
}
if
(
d_ddy
)
{
d_ddy
->
Resize
(
y
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
d_ddy
);
}
auto
&
place
=
*
dev_ctx
.
eigen_device
();
DenseTensor
ddx_safe
,
ddy_safe
;
funcs
::
GetDoubleGradSafeTensor
<
Context
,
T
>
(
dev_ctx
,
x
,
ddx
.
get_ptr
(),
&
ddx_safe
);
funcs
::
GetDoubleGradSafeTensor
<
Context
,
T
>
(
dev_ctx
,
y
,
ddy
.
get_ptr
(),
&
ddy_safe
);
if
(
d_ddout
.
get_ptr
())
{
if
(
d_x
)
{
// d_x = ddy * d_ddout
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
ddy_safe
,
*
(
d_ddout
.
get_ptr
()),
d_x
,
axis
);
}
if
(
d_y
)
{
// d_y = ddx * d_ddout
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
ddx_safe
,
*
(
d_ddout
.
get_ptr
()),
d_y
,
axis
);
}
}
if
(
d_dout
)
{
// get d_dout
// d_dout = ddy * d_dx + d_dy * ddx
DenseTensor
d_dout_tmp
;
d_dout_tmp
.
Resize
(
dout
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
&
d_dout_tmp
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
d_dy
,
ddx_safe
,
d_dout
,
axis
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
ddy_safe
,
d_dx
,
&
d_dout_tmp
,
axis
);
auto
d_dout_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
d_dout
);
auto
d_dout_tmp_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
d_dout_tmp
);
d_dout_t
.
device
(
place
)
=
d_dout_t
+
d_dout_tmp_t
;
}
if
(
d_ddx
)
{
// get d_ddx
// d_ddx = dout * d_dy + y * d_ddout
DenseTensor
d_ddx_tmp
;
d_ddx_tmp
.
Resize
(
ddx
->
dims
());
dev_ctx
.
template
Alloc
<
T
>(
&
d_ddx_tmp
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
dout
,
d_dy
,
d_ddx
,
axis
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
y
,
*
(
d_ddout
.
get_ptr
()),
&
d_ddx_tmp
,
axis
);
auto
d_ddx_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
d_ddx
);
auto
d_ddx_tmp_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
d_ddx_tmp
);
d_ddx_t
.
device
(
place
)
=
d_ddx_t
+
d_ddx_tmp_t
;
}
if
(
d_ddy
)
{
// get d_ddy
// d_ddy = dout * d_dx + x * d_ddout
DenseTensor
d_ddy_tmp
;
d_ddy_tmp
.
Resize
(
ddy
->
dims
());
dev_ctx
.
template
Alloc
<
T
>(
&
d_ddy_tmp
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
dout
,
d_dx
,
d_ddy
,
axis
);
funcs
::
DefaultElementwiseOperator
<
Context
,
T
,
funcs
::
MultiplyFunctor
<
T
>
,
funcs
::
InverseMultiplyFunctor
<
T
>>
(
dev_ctx
,
x
,
*
(
d_ddout
.
get_ptr
()),
&
d_ddy_tmp
,
axis
);
auto
d_ddy_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
d_ddy
);
auto
d_ddy_tmp_t
=
phi
::
EigenVector
<
T
>::
Flatten
(
d_ddy_tmp
);
d_ddy_t
.
device
(
place
)
=
d_ddy_t
+
d_ddy_tmp_t
;
}
}
}
// namespace phi
paddle/phi/ops/compat/elementwise_sig.cc
浏览文件 @
452c75b8
...
...
@@ -122,6 +122,31 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
{
GradVarName
(
"Y"
),
"DOut"
,
"DDOut"
});
}
KernelSignature
ElementwiseMulGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"multiply_grad"
,
{
"X"
,
"Y"
,
GradVarName
(
"Out"
)},
{
"axis"
},
{
GradVarName
(
"X"
),
GradVarName
(
"Y"
)});
}
KernelSignature
ElementwiseMulDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"multiply_double_grad"
,
{
"X"
,
"Y"
,
"DOut"
,
"DDX"
,
"DDY"
},
{
"axis"
},
{
GradVarName
(
"X"
),
GradVarName
(
"Y"
),
"DDOut"
});
}
KernelSignature
ElementwiseMulTripleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"multiply_triple_grad"
,
{
"X"
,
"Y"
,
"DOut"
,
"DDX"
,
"DDY"
,
"D_DX"
,
"D_DY"
,
"D_DDOut"
},
{
"axis"
},
{
"D_X"
,
"D_Y"
,
"D_DOut"
,
"D_DDX"
,
"D_DDY"
});
}
}
// namespace phi
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add
,
add
);
...
...
@@ -135,6 +160,9 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_sub_grad_grad
,
subtract_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_div_grad
,
divide_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_div_grad_grad
,
divide_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_mul_grad
,
multiply_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_mul_grad_grad
,
multiply_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_mul_triple_grad
,
multiply_triple_grad
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_add
,
phi
::
ElementwiseAddOpArgumentMapping
);
...
...
@@ -158,3 +186,9 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad,
phi
::
ElementwiseDivGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_div_grad_grad
,
phi
::
ElementwiseDivDoubleGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_mul_grad
,
phi
::
ElementwiseMulGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_mul_grad_grad
,
phi
::
ElementwiseMulDoubleGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_mul_triple_grad
,
phi
::
ElementwiseMulTripleGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录