Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
befd6d53
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
befd6d53
编写于
12月 03, 2020
作者:
Z
Zhang Ting
提交者:
GitHub
12月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
improve elementwise_add_grad perf (#29277)
* improve performance of elementwise_sum_grad
上级
ebf68919
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
332 addition
and
37 deletion
+332
-37
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+304
-5
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+28
-32
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
befd6d53
...
@@ -11,12 +11,16 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,12 +11,16 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <functional>
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
#define WARPSIZE 32
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
...
@@ -74,11 +78,10 @@ static __global__ void SimpleElemwiseAddGradCUDAKernel(const T* dout,
...
@@ -74,11 +78,10 @@ static __global__ void SimpleElemwiseAddGradCUDAKernel(const T* dout,
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
ElementwiseAddGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
auto
size
=
x
->
numel
();
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
grid_size
=
...
@@ -90,6 +93,302 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
...
@@ -90,6 +93,302 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
inline
static
bool
UseReduceFirstAxisRank1
(
const
framework
::
DDim
&
dout_dims
,
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
)
{
int
start_axis
=
(
axis
==
-
1
?
std
::
abs
(
x_dims
.
size
()
-
y_dims
.
size
())
:
axis
);
if
(
y_dims
[
y_dims
.
size
()
-
1
]
==
1
)
{
return
false
;
}
if
(
y_dims
.
size
()
>
1
)
{
for
(
int
i
=
0
;
i
<
y_dims
.
size
()
-
1
;
++
i
)
{
if
(
y_dims
[
i
]
!=
1
)
{
return
false
;
}
}
return
true
;
}
else
if
(
start_axis
==
x_dims
.
size
()
-
1
)
{
return
true
;
}
return
false
;
}
inline
static
bool
UseReduceFirstAxisRank2
(
const
framework
::
DDim
&
dout_dims
,
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
)
{
int
start_axis
=
(
axis
==
-
1
?
std
::
abs
(
x_dims
.
size
()
-
y_dims
.
size
())
:
axis
);
if
(
y_dims
.
size
()
<
2
||
x_dims
[
x_dims
.
size
()
-
2
]
!=
y_dims
[
y_dims
.
size
()
-
2
]
||
x_dims
[
x_dims
.
size
()
-
1
]
!=
y_dims
[
y_dims
.
size
()
-
1
])
{
return
false
;
}
if
(
start_axis
==
x_dims
.
size
()
-
2
)
{
return
true
;
}
else
if
(
start_axis
==
0
)
{
for
(
int
i
=
0
;
i
<
y_dims
.
size
()
-
2
;
++
i
)
{
if
(
y_dims
[
i
]
!=
1
)
{
return
false
;
}
}
return
true
;
}
return
false
;
}
inline
static
bool
UseReduceSecondAxisRank2
(
const
framework
::
DDim
&
dout_dims
,
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
int
*
start
,
int
*
end
)
{
if
(
x_dims
.
size
()
!=
y_dims
.
size
()
||
y_dims
.
size
()
<
3
)
{
return
false
;
}
auto
y_dims_vec
=
framework
::
vectorize
(
y_dims
);
auto
start_iter
=
std
::
find
(
y_dims_vec
.
begin
(),
y_dims_vec
.
end
(),
1
);
auto
end_iter
=
std
::
find
(
y_dims_vec
.
rbegin
(),
y_dims_vec
.
rend
(),
1
);
if
(
start_iter
==
y_dims_vec
.
end
()
||
start_iter
==
y_dims_vec
.
end
()
-
1
)
{
return
false
;
}
else
{
*
start
=
std
::
distance
(
y_dims_vec
.
begin
(),
start_iter
);
*
end
=
y_dims_vec
.
size
()
-
1
-
std
::
distance
(
y_dims_vec
.
rbegin
(),
end_iter
);
for
(
int
i
=
*
start
;
i
<=
*
end
;
++
i
)
{
if
(
y_dims
[
i
]
!=
1
)
{
return
false
;
}
}
return
true
;
}
}
template
<
typename
T
,
typename
OP
>
__global__
__launch_bounds__
(
1024
)
void
ReduceFirstAixsKernel
(
const
T
*
in
,
T
*
out
,
const
int64_t
num_rows
,
const
int64_t
num_cols
,
OP
op
,
T
init
)
{
int
row
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T
sum
=
init
;
if
(
row
<
num_rows
&&
col
<
num_cols
)
sum
=
in
[
row
*
num_cols
+
col
];
__shared__
__align__
(
alignof
(
T
))
char
partial_sums_raw
[
WARPSIZE
*
(
WARPSIZE
+
1
)
*
sizeof
(
T
)];
T
*
partial_sums
=
reinterpret_cast
<
T
*>
(
partial_sums_raw
);
row
+=
gridDim
.
y
*
blockDim
.
y
;
if
(
col
<
num_cols
)
{
for
(;
row
<
num_rows
;
row
+=
gridDim
.
y
*
blockDim
.
y
)
{
sum
=
op
(
sum
,
in
[
row
*
num_cols
+
col
]);
}
}
partial_sums
[
threadIdx
.
x
*
(
WARPSIZE
+
1
)
+
threadIdx
.
y
]
=
sum
;
__syncthreads
();
if
(
threadIdx
.
y
==
0
&&
col
<
num_cols
)
{
T
s
=
partial_sums
[
threadIdx
.
x
*
(
WARPSIZE
+
1
)];
const
int
numRowsThisBlock
=
min
(
static_cast
<
int64_t
>
(
blockDim
.
y
),
num_rows
-
blockIdx
.
y
*
blockDim
.
y
);
for
(
int
row
=
1
;
row
<
numRowsThisBlock
;
++
row
)
{
T
t
=
partial_sums
[
threadIdx
.
x
*
(
WARPSIZE
+
1
)
+
row
];
s
=
op
(
s
,
t
);
}
out
[
col
*
gridDim
.
y
+
blockIdx
.
y
]
=
s
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
static
void
ElemwiseYGradRank1CUDA
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
dout
,
const
int
rows
,
const
int
cols
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
dim3
block_dim
(
WARPSIZE
,
std
::
min
(
rows
,
1024
/
WARPSIZE
));
dim3
grid_dim
((
cols
+
(
WARPSIZE
-
1
))
/
WARPSIZE
,
1
,
1
);
if
(
dx
)
{
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
TensorCopy
(
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dx
);
}
if
(
dy
)
{
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
dout_data
=
dout
.
data
<
T
>
();
T
*
dy_data
=
dy
->
data
<
T
>
();
auto
stream
=
ctx
.
template
device_context
<
DeviceContext
>().
stream
();
ReduceFirstAixsKernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
dout_data
,
dy_data
,
rows
,
cols
,
AddFunctor
<
T
>
(),
static_cast
<
T
>
(
0
));
}
}
template
<
typename
T
,
typename
OP
>
__global__
__launch_bounds__
(
1024
)
void
ReduceFirstOrSecondAxisKernel
(
const
T
*
in
,
T
*
out
,
const
int
num_planes
,
const
int
num_rows
,
const
int
num_cols
,
OP
op
,
T
init
)
{
const
int
gid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
const
int
elems_per_plane
=
num_rows
*
num_cols
;
const
int
plane
=
gid
/
num_cols
;
const
int
col
=
gid
%
num_cols
;
if
(
plane
>=
num_planes
)
return
;
if
(
num_rows
==
1
)
{
out
[
plane
*
elems_per_plane
+
col
]
=
in
[
plane
*
elems_per_plane
+
col
];
return
;
}
T
sum
=
op
(
in
[
plane
*
elems_per_plane
+
col
],
in
[
plane
*
elems_per_plane
+
num_cols
+
col
]);
for
(
int
row
=
2
;
row
<
num_rows
;
++
row
)
{
sum
=
op
(
sum
,
in
[
plane
*
elems_per_plane
+
row
*
num_cols
+
col
]);
}
out
[
plane
*
num_cols
+
col
]
=
sum
;
}
template
<
typename
DeviceContext
,
typename
T
>
static
void
ElemwiseYGradRank2CUDA
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
dout
,
const
int
planes
,
const
int
rows
,
const
int
cols
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
num_threads
=
128
;
int
num_blocks
=
(
rows
+
num_threads
-
1
)
/
num_threads
;
if
(
planes
!=
1
)
{
num_blocks
=
(
planes
*
cols
+
num_threads
-
1
)
/
num_threads
;
}
if
(
dx
)
{
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
TensorCopy
(
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dx
);
}
if
(
dy
)
{
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
dout_data
=
dout
.
data
<
T
>
();
T
*
dy_data
=
dy
->
data
<
T
>
();
auto
stream
=
ctx
.
template
device_context
<
DeviceContext
>().
stream
();
ReduceFirstOrSecondAxisKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
dout_data
,
dy_data
,
planes
,
rows
,
cols
,
AddFunctor
<
T
>
(),
static_cast
<
T
>
(
0
));
}
}
template
<
typename
DeviceContext
,
typename
T
>
static
bool
ElemwiseGradUseReduce
(
const
framework
::
ExecutionContext
&
ctx
,
const
int
axis
,
const
framework
::
DDim
x_dims
,
const
framework
::
DDim
y_dims
,
const
framework
::
Tensor
&
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
start
=
0
;
int
end
=
0
;
auto
x_dims_vec
=
framework
::
vectorize
(
x_dims
);
if
(
UseReduceFirstAxisRank1
(
dout
.
dims
(),
x_dims
,
y_dims
,
axis
))
{
int
rows
=
std
::
accumulate
(
x_dims_vec
.
begin
(),
x_dims_vec
.
end
()
-
1
,
1
,
std
::
multiplies
<
int
>
());
int
cols
=
dx
->
dims
()[
dx
->
dims
().
size
()
-
1
];
if
(
cols
>
512
&&
cols
<
4096
)
{
ElemwiseYGradRank1CUDA
<
DeviceContext
,
T
>
(
ctx
,
dout
,
rows
,
cols
,
dx
,
dy
);
return
true
;
}
}
if
(
UseReduceFirstAxisRank2
(
dout
.
dims
(),
x_dims
,
y_dims
,
axis
))
{
int
rows
=
std
::
accumulate
(
x_dims_vec
.
begin
(),
x_dims_vec
.
end
()
-
2
,
1
,
std
::
multiplies
<
int
>
());
int
cols
=
dx
->
dims
()[
dx
->
dims
().
size
()
-
1
]
*
dx
->
dims
()[
dx
->
dims
().
size
()
-
2
];
if
(
cols
>
4096
)
{
ElemwiseYGradRank2CUDA
<
DeviceContext
,
T
>
(
ctx
,
dout
,
1
,
rows
,
cols
,
dx
,
dy
);
return
true
;
}
}
if
(
UseReduceSecondAxisRank2
(
dout
.
dims
(),
x_dims
,
y_dims
,
axis
,
&
start
,
&
end
))
{
int
planes
=
std
::
accumulate
(
x_dims_vec
.
begin
(),
x_dims_vec
.
begin
()
+
start
,
1
,
std
::
multiplies
<
int
>
());
int
rows
=
std
::
accumulate
(
x_dims_vec
.
begin
()
+
start
,
x_dims_vec
.
begin
()
+
end
+
1
,
1
,
std
::
multiplies
<
int
>
());
int
cols
=
std
::
accumulate
(
x_dims_vec
.
begin
()
+
end
+
1
,
x_dims_vec
.
end
(),
1
,
std
::
multiplies
<
int
>
());
if
(
rows
/
(
planes
*
cols
)
<
16
)
{
ElemwiseYGradRank2CUDA
<
DeviceContext
,
T
>
(
ctx
,
dout
,
planes
,
rows
,
cols
,
dx
,
dy
);
return
true
;
}
}
return
false
;
}
template
<
typename
T
>
class
ElementwiseAddGradKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
ElemwiseGradKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
ElemwiseGradKernel
<
T
>::
Compute
(
ctx
);
using
Tensor
=
framework
::
Tensor
;
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
// skip out
auto
*
out
=
dout
;
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
// Special case when dy is not needed and dx doesn't reduce
if
(
dx
!=
nullptr
&&
dy
==
nullptr
&&
dx
->
dims
()
==
dout
->
dims
())
{
VLOG
(
4
)
<<
"Special case when dy is not needed and dx doesn't "
"reduce"
;
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dx
);
}
else
if
(
dx
==
nullptr
&&
dy
!=
nullptr
&&
dy
->
dims
()
==
dout
->
dims
())
{
VLOG
(
4
)
<<
"Special case when dx is not needed and dy doesn't "
"reduce"
;
framework
::
TensorCopy
(
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dy
);
}
else
if
(
dx
&&
dy
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
ElementwiseAddGrad
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
else
if
(
dx
&&
dx
->
dims
()
==
dout
->
dims
()
&&
ElemwiseGradUseReduce
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
axis
,
x
->
dims
(),
y
->
dims
(),
*
dout
,
dx
,
dy
))
{
}
else
if
(
dy
&&
dy
->
dims
()
==
dout
->
dims
()
&&
ElemwiseGradUseReduce
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
axis
,
x
->
dims
(),
y
->
dims
(),
*
dout
,
dy
,
dx
))
{
}
else
{
DefaultElementwiseAddGrad
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
befd6d53
...
@@ -22,9 +22,10 @@ namespace paddle {
...
@@ -22,9 +22,10 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_add
(
const
framework
::
ExecutionContext
&
ctx
,
void
DefaultElementwiseAddGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
->
dims
();
auto
y_dims
=
y
->
dims
();
auto
y_dims
=
y
->
dims
();
...
@@ -57,7 +58,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
...
@@ -57,7 +58,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
SameDimsElemwiseAdd
<
DeviceContext
,
T
>
same_dims_add
;
SameDimsElemwiseAdd
<
DeviceContext
,
T
>
same_dims_add
;
same_dims_add
(
ctx
,
x
,
y
,
z
);
same_dims_add
(
ctx
,
x
,
y
,
z
);
}
else
{
}
else
{
default_elementwise_ad
d
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
DefaultElementwiseAddGra
d
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
}
}
}
}
};
};
...
@@ -68,13 +69,12 @@ struct IdentityGrad {
...
@@ -68,13 +69,12 @@ struct IdentityGrad {
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
void
DefaultElementwiseAddGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
IdentityGrad
<
T
>
,
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
IdentityGrad
<
T
>
,
...
@@ -87,11 +87,10 @@ template <typename DeviceContext, typename T>
...
@@ -87,11 +87,10 @@ template <typename DeviceContext, typename T>
typename
std
::
enable_if
<
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
&&
std
::
is_floating_point
<
T
>::
value
&&
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
ElementwiseAddGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
framework
::
Tensor
*
dy
)
{
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
if
(
dx
)
{
if
(
dx
)
{
blas
.
VCOPY
(
dout
->
numel
(),
dout
->
data
<
T
>
(),
blas
.
VCOPY
(
dout
->
numel
(),
dout
->
data
<
T
>
(),
...
@@ -108,12 +107,11 @@ template <typename DeviceContext, typename T>
...
@@ -108,12 +107,11 @@ template <typename DeviceContext, typename T>
typename
std
::
enable_if
<
typename
std
::
enable_if
<
!
std
::
is_floating_point
<
T
>::
value
&&
!
std
::
is_floating_point
<
T
>::
value
&&
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
ElementwiseAddGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
framework
::
Tensor
*
dy
)
{
DefaultElementwiseAddGrad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
default_elementwise_add_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
}
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -121,11 +119,10 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
...
@@ -121,11 +119,10 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
ElementwiseAddGrad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
framework
::
Tensor
*
dy
);
#endif
#endif
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
...
@@ -158,10 +155,9 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
...
@@ -158,10 +155,9 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
*
dout
,
ctx
.
GetPlace
(),
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dy
);
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dy
);
}
else
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
}
else
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_add_g
rad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
ElementwiseAddG
rad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
else
{
}
else
{
default_elementwise_add_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
DefaultElementwiseAddGrad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
dy
);
}
}
}
}
};
};
...
@@ -186,8 +182,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
...
@@ -186,8 +182,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
y
,
ddy
,
&
ddy_safe
);
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
y
,
ddy
,
&
ddy_safe
);
ddout
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ddout
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
default_elementwise_ad
d
<
DeviceContext
,
T
>
(
ctx
,
&
ddx_safe
,
&
ddy_safe
,
DefaultElementwiseAddGra
d
<
DeviceContext
,
T
>
(
ctx
,
&
ddx_safe
,
&
ddy_safe
,
ddout
);
ddout
);
}
}
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录