Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
26736576
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
26736576
编写于
11月 14, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move RowwiseAdd functor to math_funcion and Add ColwiseSum functor.
上级
48947b51
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
66 addition
and
80 deletion
+66
-80
paddle/operators/gru_op.h
paddle/operators/gru_op.h
+2
-8
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+5
-10
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+5
-0
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+5
-0
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+13
-0
paddle/operators/math/math_function_impl.h
paddle/operators/math/math_function_impl.h
+36
-1
paddle/operators/math/sequence2batch.cc
paddle/operators/math/sequence2batch.cc
+0
-23
paddle/operators/math/sequence2batch.cu
paddle/operators/math/sequence2batch.cu
+0
-31
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+0
-7
未找到文件。
paddle/operators/gru_op.h
浏览文件 @
26736576
...
...
@@ -205,14 +205,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
}
if
(
bias_grad
)
{
bias_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
m
=
static_cast
<
int
>
(
batch_gate_grad
.
dims
()[
0
]);
int
n
=
static_cast
<
int
>
(
batch_gate_grad
.
dims
()[
1
]);
Tensor
ones
;
ones
.
mutable_data
<
T
>
({
m
},
context
.
GetPlace
());
math
::
SetConstant
<
Place
,
T
>
set
;
set
(
dev_ctx
,
&
ones
,
static_cast
<
T
>
(
1
));
math
::
gemv
<
Place
,
T
>
(
dev_ctx
,
true
,
m
,
n
,
1.
,
batch_gate_grad
.
data
<
T
>
(),
ones
.
data
<
T
>
(),
0.
,
bias_grad
->
data
<
T
>
());
math
::
ColwiseSum
<
Place
,
T
>
col_sum
;
col_sum
(
dev_ctx
,
batch_gate_grad
,
bias_grad
);
}
}
...
...
paddle/operators/lstm_op.h
浏览文件 @
26736576
...
...
@@ -341,16 +341,11 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
if
(
bias
&&
bias_g
)
{
/* backward bias */
int
m
=
static_cast
<
int
>
(
batch_gate_g
.
dims
()[
0
]);
int
n
=
static_cast
<
int
>
(
batch_gate_g
.
dims
()[
1
]);
Tensor
ones
;
ones
.
mutable_data
<
T
>
({
m
},
ctx
.
GetPlace
());
math
::
SetConstant
<
Place
,
T
>
set
;
set
(
device_ctx
,
&
ones
,
static_cast
<
T
>
(
1.0
));
math
::
gemv
<
Place
,
T
>
(
device_ctx
,
true
,
m
,
n
,
1.
,
batch_gate_g
.
data
<
T
>
(),
ones
.
data
<
T
>
(),
0.
,
bias_g
->
data
<
T
>
());
Tensor
b_g
=
*
bias_g
;
b_g
.
Resize
({
bias_g
->
numel
(),
1
});
Tensor
gate_bias_g
=
b_g
.
Slice
(
0
,
4
*
frame_size
);
math
::
ColwiseSum
<
Place
,
T
>
col_sum
;
col_sum
(
device_ctx
,
batch_gate_g
,
&
gate_bias_g
);
}
if
(
h0
&&
h0_g
)
{
...
...
paddle/operators/math/math_function.cc
浏览文件 @
26736576
...
...
@@ -308,6 +308,11 @@ void set_constant(const platform::DeviceContext& context,
#endif
}
template
struct
RowwiseAdd
<
platform
::
CPUPlace
,
float
>;
template
struct
RowwiseAdd
<
platform
::
CPUPlace
,
double
>;
template
struct
ColwiseSum
<
platform
::
CPUPlace
,
float
>;
template
struct
ColwiseSum
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.cu
浏览文件 @
26736576
...
...
@@ -292,6 +292,11 @@ void set_constant_with_place<platform::GPUPlace>(
TensorSetConstantGPU
(
context
,
tensor
,
value
));
}
template
struct
RowwiseAdd
<
platform
::
GPUPlace
,
float
>;
template
struct
RowwiseAdd
<
platform
::
GPUPlace
,
double
>;
template
struct
ColwiseSum
<
platform
::
GPUPlace
,
float
>;
template
struct
ColwiseSum
<
platform
::
GPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.h
浏览文件 @
26736576
...
...
@@ -117,6 +117,19 @@ void set_constant_with_place(const platform::DeviceContext& context,
void
set_constant
(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
float
value
);
template
<
typename
Place
,
typename
T
>
struct
RowwiseAdd
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
vec
,
framework
::
Tensor
*
output
);
};
template
<
typename
Place
,
typename
T
>
struct
ColwiseSum
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
vec
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function_impl.h
浏览文件 @
26736576
...
...
@@ -43,6 +43,41 @@ void Transpose<Place, T, Rank>::operator()(
auto
*
dev
=
context
.
GetEigenDevice
<
Place
>
();
eigen_out
.
device
(
*
dev
)
=
eigen_in
.
shuffle
(
permute
);
}
template
<
typename
Place
,
typename
T
>
void
RowwiseAdd
<
Place
,
T
>::
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
vector
,
framework
::
Tensor
*
output
)
{
auto
in_dims
=
input
.
dims
();
auto
size
=
input
.
numel
()
/
in_dims
[
0
];
PADDLE_ENFORCE_EQ
(
vector
.
numel
(),
size
);
PADDLE_ENFORCE_EQ
(
output
->
dims
(),
in_dims
);
auto
in
=
framework
::
EigenMatrix
<
T
>::
From
(
input
);
auto
vec
=
framework
::
EigenMatrix
<
T
>::
From
(
vector
);
auto
out
=
framework
::
EigenMatrix
<
T
>::
From
(
*
output
);
Eigen
::
array
<
int
,
2
>
shape
({{
1
,
static_cast
<
int
>
(
size
)}});
Eigen
::
array
<
int
,
2
>
bcast
({{
static_cast
<
int
>
(
in_dims
[
0
]),
1
}});
out
.
device
(
*
context
.
GetEigenDevice
<
Place
>
())
=
in
+
vec
.
reshape
(
shape
).
broadcast
(
bcast
);
}
template
<
typename
Place
,
typename
T
>
void
ColwiseSum
<
Place
,
T
>::
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
*
vector
)
{
auto
in_dims
=
input
.
dims
();
auto
size
=
input
.
numel
()
/
in_dims
[
0
];
PADDLE_ENFORCE_EQ
(
vector
->
numel
(),
size
);
auto
vec
=
framework
::
EigenMatrix
<
T
>::
From
(
*
vector
);
auto
in
=
framework
::
EigenMatrix
<
T
>::
From
(
input
);
Eigen
::
array
<
int
,
2
>
shape
({{
1
,
static_cast
<
int
>
(
size
)}});
vec
.
reshape
(
shape
).
device
(
*
context
.
GetEigenDevice
<
Place
>
())
=
in
.
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}})).
reshape
(
shape
);
}
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/sequence2batch.cc
浏览文件 @
26736576
...
...
@@ -56,29 +56,6 @@ template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
template
class
Batch2LoDTensorFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Batch2LoDTensorFunctor
<
platform
::
CPUPlace
,
double
>;
template
<
typename
T
>
struct
RowwiseAdd
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
bias
,
framework
::
Tensor
*
output
)
{
auto
in_dims
=
input
.
dims
();
auto
size
=
input
.
numel
()
/
in_dims
[
0
];
PADDLE_ENFORCE_EQ
(
bias
.
numel
(),
size
);
PADDLE_ENFORCE_EQ
(
output
->
dims
(),
in_dims
);
auto
in
=
EigenMatrix
<
T
>::
From
(
input
);
auto
b
=
EigenMatrix
<
T
>::
From
(
bias
);
auto
out
=
EigenMatrix
<
T
>::
From
(
*
output
);
Eigen
::
array
<
int
,
2
>
bshape
({{
1
,
static_cast
<
int
>
(
size
)}});
Eigen
::
array
<
int
,
2
>
bcast
({{
static_cast
<
int
>
(
in_dims
[
0
]),
1
}});
out
.
device
(
*
context
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
in
+
b
.
reshape
(
bshape
).
broadcast
(
bcast
);
}
};
template
struct
RowwiseAdd
<
platform
::
CPUPlace
,
float
>;
template
struct
RowwiseAdd
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/sequence2batch.cu
浏览文件 @
26736576
...
...
@@ -74,37 +74,6 @@ template class LoDTensor2BatchFunctor<platform::GPUPlace, double>;
template
class
Batch2LoDTensorFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Batch2LoDTensorFunctor
<
platform
::
GPUPlace
,
double
>;
template
<
typename
T
>
__global__
void
RowwiseAddKernel
(
const
T
*
src
,
const
T
*
b
,
T
*
dst
,
int64_t
height
,
int64_t
width
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
height
*
width
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int64_t
h
=
i
/
width
;
int64_t
w
=
i
%
width
;
dst
[
h
*
width
+
w
]
=
src
[
h
*
width
+
w
]
+
b
[
w
];
}
}
template
<
typename
T
>
struct
RowwiseAdd
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
bias
,
framework
::
Tensor
*
output
)
{
auto
in_dims
=
input
.
dims
();
auto
size
=
input
.
numel
()
/
in_dims
[
0
];
PADDLE_ENFORCE_EQ
(
bias
.
numel
(),
size
);
PADDLE_ENFORCE_EQ
(
output
->
dims
(),
in_dims
);
int
block
=
512
;
int
grid
=
(
input
.
numel
()
+
block
-
1
)
/
block
;
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
RowwiseAddKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
T
>
(),
bias
.
data
<
T
>
(),
output
->
data
<
T
>
(),
in_dims
[
0
],
size
);
}
};
template
struct
RowwiseAdd
<
platform
::
GPUPlace
,
float
>;
template
struct
RowwiseAdd
<
platform
::
GPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/sequence2batch.h
浏览文件 @
26736576
...
...
@@ -164,13 +164,6 @@ class Batch2LoDTensorFunctor {
}
};
template
<
typename
Place
,
typename
T
>
struct
RowwiseAdd
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
bias
,
framework
::
Tensor
*
output
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录