Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
88c22e9d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
88c22e9d
编写于
2月 23, 2018
作者:
Y
Yu Yang
提交者:
Yang Yang(Tony)
2月 22, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speed up elemwise grad (#8402)
* Speed up elemwise grad * Fix bug * Add macro for MAX_BLOCK_DIM
上级
d3162339
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
259 addition
and
57 deletion
+259
-57
paddle/fluid/operators/elementwise_add_op.h
paddle/fluid/operators/elementwise_add_op.h
+5
-57
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+254
-0
未找到文件。
paddle/fluid/operators/elementwise_add_op.h
浏览文件 @
88c22e9d
...
...
@@ -41,59 +41,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
};
template
<
typename
T
>
struct
ElementwiseAddGradFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Y
,
typename
Z
,
typename
dX
,
typename
dY
,
typename
dZ
>
void
operator
()(
Device
d
,
X
x
,
Y
y
,
Z
z
,
dX
dx
,
dY
dy
,
dZ
dz
)
{
auto
dz_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dz
);
if
(
dx
)
{
auto
dx_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dx
);
dx_e
.
device
(
d
)
=
dz_e
;
}
if
(
dy
)
{
auto
dy_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dy
);
dy_e
.
device
(
d
)
=
dz_e
;
}
}
};
template
<
typename
T
>
struct
ElementwiseAddBroadCastGradFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Y
,
typename
Z
,
typename
dX
,
typename
dY
,
typename
dZ
,
typename
Pre
,
typename
N
>
void
operator
()(
Device
d
,
X
x
,
Y
y
,
Z
z
,
dX
dx
,
dY
dy
,
dZ
dz
,
Pre
pre
,
N
n
)
{
auto
dz_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dz
);
if
(
dx
)
{
auto
dx_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dx
);
dx_e
.
device
(
d
)
=
dz_e
;
}
if
(
dy
)
{
auto
dy_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dy
);
dy_e
.
device
(
d
)
=
dz_e
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
pre
,
n
))
.
sum
(
Eigen
::
array
<
int
,
1
>
{{
0
}});
}
}
};
template
<
typename
T
>
struct
ElementwiseAddBroadCast2GradFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Y
,
typename
Z
,
typename
dX
,
typename
dY
,
typename
dZ
,
typename
Pre
,
typename
N
,
typename
Post
>
void
operator
()(
Device
d
,
X
x
,
Y
y
,
Z
z
,
dX
dx
,
dY
dy
,
dZ
dz
,
Pre
pre
,
N
n
,
Post
post
)
{
auto
dz_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dz
);
if
(
dx
)
{
auto
dx_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dx
);
dx_e
.
device
(
d
)
=
dz_e
;
}
if
(
dy
)
{
auto
dy_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dy
);
dy_e
.
device
(
d
)
=
dz_e
.
reshape
(
Eigen
::
DSizes
<
int
,
3
>
(
pre
,
n
,
post
))
.
sum
(
Eigen
::
array
<
int
,
2
>
{{
0
,
2
}});
}
}
struct
IdentityGrad
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -109,10 +58,9 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElementwiseGradCompute
<
DeviceContext
,
T
,
ElementwiseAddGradFunctor
<
T
>
,
ElementwiseAddBroadCastGradFunctor
<
T
>
,
ElementwiseAddBroadCast2GradFunctor
<
T
>>
(
ctx
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
);
ElemwiseGradCompute
<
DeviceContext
,
T
,
IdentityGrad
<
T
>
,
IdentityGrad
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
IdentityGrad
<
T
>
(),
IdentityGrad
<
T
>
());
}
};
...
...
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
88c22e9d
...
...
@@ -20,9 +20,11 @@ limitations under the License. */
#ifdef __NVCC__
#include <thrust/iterator/iterator_adaptor.h>
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
#endif
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -311,6 +313,258 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR
(
Div
,
EIGEN_DIV
);
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
struct
ElemwiseGradNoBroadcast
{
const
T
*
x_
;
const
T
*
y_
;
const
T
*
out_
;
const
T
*
dout_
;
HOSTDEVICE
void
operator
()(
size_t
i
)
{
if
(
dx_
!=
nullptr
)
{
dx_
[
i
]
=
dx_op_
(
x_
[
i
],
y_
[
i
],
out_
[
i
],
dout_
[
i
]);
}
if
(
dy_
!=
nullptr
)
{
dy_
[
i
]
=
dx_op_
(
x_
[
i
],
y_
[
i
],
out_
[
i
],
dout_
[
i
]);
}
}
DX_OP
dx_op_
;
DY_OP
dy_op_
;
T
*
dx_
;
T
*
dy_
;
};
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast1CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
for
(
int
i
=
0
;
i
<
h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
w
;
++
j
)
{
int
x_offset
=
i
*
w
+
j
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
T
tmp
=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
if
(
i
==
0
)
{
dy
[
j
]
=
tmp
;
}
else
{
dy
[
j
]
+=
tmp
;
}
}
}
}
}
#ifdef __NVCC__
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
extern
__shared__
char
shm_buffer
[];
T
*
shm
=
reinterpret_cast
<
T
*>
(
shm_buffer
);
int
j
=
blockIdx
.
x
;
int
i
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
shm
[
tid
]
=
0
;
do
{
int
x_offset
=
i
*
w
+
j
;
if
(
dx
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
)
{
shm
[
tid
]
+=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
i
+=
ELEMWISE_MAX_BLOCK_DIM
;
}
while
(
i
<
h
);
if
(
dy
)
{
__syncthreads
();
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
// Sum, could be optimized
if
(
threadIdx
.
x
==
0
)
{
for
(
int
k
=
1
;
k
<
h
;
++
k
)
{
shm
[
0
]
+=
shm
[
k
];
}
dy
[
j
]
=
shm
[
0
];
}
}
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast1CUDA
(
cudaStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
h
);
int
gird_size
=
w
;
int
shared_mem_size
=
block_size
*
sizeof
(
T
);
ElemwiseGradBroadcast1CUDAKernel
<<<
gird_size
,
block_size
,
shared_mem_size
,
stream
>>>
(
x
,
y
,
out
,
dout
,
h
,
w
,
dx_op
,
dy_op
,
dx
,
dy
);
}
#endif
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast2CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
for
(
int
i
=
0
;
i
<
pre
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
for
(
int
k
=
0
;
k
<
post
;
++
k
)
{
int
x_offset
=
i
*
n
*
post
+
j
*
post
+
k
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
T
tmp
=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
if
(
i
==
0
&&
k
==
0
)
{
dy
[
j
]
=
tmp
;
}
else
{
dy
[
j
]
+=
tmp
;
}
}
}
}
}
}
#ifdef __NVCC__
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast2CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
tid
=
threadIdx
.
x
;
int
j
=
blockIdx
.
x
;
extern
__shared__
char
shm_buffer
[];
T
*
shm
=
reinterpret_cast
<
T
*>
(
shm_buffer
);
shm
[
tid
]
=
0
;
int
ttid
=
tid
;
while
(
true
)
{
int
i
=
ttid
/
post
;
int
k
=
ttid
%
post
;
if
(
i
>=
pre
)
break
;
int
x_offset
=
i
*
n
*
post
+
j
*
post
+
k
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
shm
[
tid
]
+=
dy_op
(
x
[
x_offset
],
y
[
j
],
out
[
x_offset
],
dout
[
x_offset
]);
}
ttid
+=
ELEMWISE_MAX_BLOCK_DIM
;
}
if
(
dy
)
{
__syncthreads
();
int
h
=
pre
*
post
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
// Sum, could be optimized
if
(
tid
==
0
)
{
for
(
int
i
=
1
;
i
<
h
;
++
i
)
{
shm
[
0
]
+=
shm
[
i
];
}
dy
[
j
]
=
shm
[
0
];
}
}
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast2CUDA
(
cudaStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
pre
*
post
);
int
gird_size
=
n
;
int
shared_mem_size
=
block_size
*
sizeof
(
T
);
ElemwiseGradBroadcast2CUDAKernel
<<<
gird_size
,
block_size
,
shared_mem_size
,
stream
>>>
(
x
,
y
,
out
,
dout
,
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
,
dy
);
}
#endif
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
void
ElemwiseGradCompute
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
,
const
framework
::
Tensor
&
out
,
const
framework
::
Tensor
&
dout
,
int
axis
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
,
DX_OP
dx_op
,
DY_OP
dy_op
)
{
if
(
x
.
dims
()
==
y
.
dims
())
{
size_t
N
=
static_cast
<
size_t
>
(
framework
::
product
(
x
.
dims
()));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
N
);
for_range
(
ElemwiseGradNoBroadcast
<
T
,
DX_OP
,
DY_OP
>
{
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())});
}
else
{
// Y is a scalar
auto
x_dim
=
x
.
dims
();
auto
y_dim
=
y
.
dims
();
if
(
y_dim
.
size
()
==
1
&&
y_dim
[
0
]
==
1
)
{
// y is a scalar
auto
extended_dims
=
framework
::
vectorize
(
x_dim
);
extended_dims
.
push_back
(
1
);
x_dim
=
framework
::
make_ddim
(
extended_dims
);
}
axis
=
(
axis
==
-
1
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
);
int
pre
,
n
,
post
;
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
pre
,
n
,
post
);
if
(
post
==
1
)
{
int
h
=
pre
;
int
w
=
n
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef __NVCC__
ElemwiseGradBroadcast1CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcast1CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
else
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef __NVCC__
ElemwiseGradBroadcast2CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcast2CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
,
typename
functor
,
typename
broadcastfunctor
,
typename
broadcast2functor
>
void
ElementwiseGradCompute
(
const
framework
::
ExecutionContext
&
ctx
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录