Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
95da78a6
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
95da78a6
编写于
12月 27, 2017
作者:
Q
qingqing01
提交者:
GitHub
12月 27, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7047 from qingqing01/rowwise_add
Optimize the rowwise add function.
上级
0d5de244
19367389
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
50 addition
and
19 deletion
+50
-19
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+21
-0
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+29
-0
paddle/operators/math/math_function_impl.h
paddle/operators/math/math_function_impl.h
+0
-19
未找到文件。
paddle/operators/math/math_function.cc
浏览文件 @
95da78a6
...
...
@@ -302,8 +302,29 @@ void set_constant(const platform::DeviceContext& context,
#endif
}
template
<
typename
T
>
struct
RowwiseAdd
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
vector
,
framework
::
Tensor
*
output
)
{
auto
in_dims
=
input
.
dims
();
auto
size
=
input
.
numel
()
/
in_dims
[
0
];
PADDLE_ENFORCE_EQ
(
vector
.
numel
(),
size
);
PADDLE_ENFORCE_EQ
(
output
->
dims
(),
in_dims
);
auto
in
=
framework
::
EigenMatrix
<
T
>::
From
(
input
);
auto
vec
=
framework
::
EigenVector
<
T
>::
Flatten
(
vector
);
auto
out
=
framework
::
EigenMatrix
<
T
>::
From
(
*
output
);
for
(
int64_t
i
=
0
;
i
<
in_dims
[
0
];
++
i
)
{
out
.
chip
(
i
,
0
)
=
in
.
chip
(
i
,
0
)
+
vec
;
}
}
};
template
struct
RowwiseAdd
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
RowwiseAdd
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
ColwiseSum
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
ColwiseSum
<
platform
::
CPUDeviceContext
,
double
>;
...
...
paddle/operators/math/math_function.cu
浏览文件 @
95da78a6
...
...
@@ -273,6 +273,35 @@ void set_constant_with_place<platform::CUDAPlace>(
TensorSetConstantGPU
(
context
,
tensor
,
value
));
}
template
<
typename
T
>
__global__
void
RowwiseAddKernel
(
const
T
*
a
,
const
T
*
b
,
T
*
c
,
int
width
,
int
num
)
{
T
tmp
=
1.0
/
width
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
h
=
i
*
tmp
;
int
w
=
i
-
h
*
width
;
c
[
i
]
=
a
[
i
]
+
b
[
w
];
}
}
template
<
typename
T
>
struct
RowwiseAdd
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
vector
,
framework
::
Tensor
*
output
)
{
auto
in_dims
=
input
.
dims
();
auto
size
=
input
.
numel
()
/
in_dims
[
0
];
PADDLE_ENFORCE_EQ
(
vector
.
numel
(),
size
);
PADDLE_ENFORCE_EQ
(
output
->
dims
(),
in_dims
);
int
blocks
=
512
;
int
grids
=
(
input
.
numel
()
+
blocks
-
1
)
/
blocks
;
RowwiseAddKernel
<
T
><<<
grids
,
blocks
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
vector
.
data
<
T
>
(),
output
->
data
<
T
>
(),
static_cast
<
int
>
(
in_dims
[
1
]),
static_cast
<
int
>
(
input
.
numel
()));
}
};
template
struct
RowwiseAdd
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
RowwiseAdd
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
ColwiseSum
<
platform
::
CUDADeviceContext
,
float
>;
...
...
paddle/operators/math/math_function_impl.h
浏览文件 @
95da78a6
...
...
@@ -45,25 +45,6 @@ void Transpose<DeviceContext, T, Rank>::operator()(
eigen_out
.
device
(
*
dev
)
=
eigen_in
.
shuffle
(
permute
);
}
template
<
typename
DeviceContext
,
typename
T
>
void
RowwiseAdd
<
DeviceContext
,
T
>::
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
vector
,
framework
::
Tensor
*
output
)
{
auto
in_dims
=
input
.
dims
();
auto
size
=
input
.
numel
()
/
in_dims
[
0
];
PADDLE_ENFORCE_EQ
(
vector
.
numel
(),
size
);
PADDLE_ENFORCE_EQ
(
output
->
dims
(),
in_dims
);
auto
in
=
framework
::
EigenMatrix
<
T
>::
From
(
input
);
auto
vec
=
framework
::
EigenMatrix
<
T
>::
From
(
vector
);
auto
out
=
framework
::
EigenMatrix
<
T
>::
From
(
*
output
);
Eigen
::
array
<
int
,
2
>
shape
({{
1
,
static_cast
<
int
>
(
size
)}});
Eigen
::
array
<
int
,
2
>
bcast
({{
static_cast
<
int
>
(
in_dims
[
0
]),
1
}});
out
.
device
(
*
context
.
eigen_device
())
=
in
+
vec
.
reshape
(
shape
).
broadcast
(
bcast
);
}
template
<
typename
DeviceContext
,
typename
T
>
void
ColwiseSum
<
DeviceContext
,
T
>::
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录