Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
7ad13fbf
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7ad13fbf
编写于
10月 18, 2017
作者:
Q
QI JUN
提交者:
GitHub
10月 18, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4876 from QiJune/sgd_op_sparse_kernel
add sparse update kernel for sgd operator
上级
c93596d3
f9681459
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
224 addition
and
29 deletion
+224
-29
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+36
-4
paddle/operators/sgd_op.cu
paddle/operators/sgd_op.cu
+60
-0
paddle/operators/sgd_op.h
paddle/operators/sgd_op.h
+35
-13
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+14
-1
python/paddle/v2/framework/tests/test_selected_rows.py
python/paddle/v2/framework/tests/test_selected_rows.py
+12
-11
python/paddle/v2/framework/tests/test_sgd_op.py
python/paddle/v2/framework/tests/test_sgd_op.py
+67
-0
未找到文件。
paddle/operators/sgd_op.cc
浏览文件 @
7ad13fbf
...
@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
"Input(Param) of SGDOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
framework
::
product
(
lr_dims
),
1
,
PADDLE_ENFORCE_EQ
(
framework
::
product
(
lr_dims
),
1
,
"Learning rate should have 1 element"
);
"Learning rate should have 1 element"
);
auto
param_dim
=
ctx
->
GetInputDim
(
"Param"
);
auto
param_dim
=
ctx
->
GetInputDim
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Grad"
),
// TODO(qijun): check dimensions of Param and Grad at complie
"Two input of SGD Op's dimension must be same."
);
// and run time.
ctx
->
SetOutputDim
(
"ParamOut"
,
param_dim
);
ctx
->
SetOutputDim
(
"ParamOut"
,
param_dim
);
}
}
};
};
class
SGDOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
SGDOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
SGDOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
SGDOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Param"
,
"Input parameter"
);
AddInput
(
"Param"
,
"Input parameter"
);
AddInput
(
"LearningRate"
,
"Learning rate of SGD"
);
AddInput
(
"LearningRate"
,
"Learning rate of SGD"
);
...
@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
...
@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
)DOC"
);
)DOC"
);
}
}
};
};
template
<
typename
T
>
struct
SparseSGDFunctor
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
Tensor
&
learning_rate
,
framework
::
Tensor
*
output
)
{
auto
in_height
=
input
.
height
();
auto
out_dims
=
output
->
dims
();
PADDLE_ENFORCE_EQ
(
in_height
,
out_dims
[
0
]);
auto
&
in_value
=
input
.
value
();
auto
&
in_rows
=
input
.
rows
();
int64_t
in_row_numel
=
in_value
.
numel
()
/
in_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in_row_numel
,
output
->
numel
()
/
in_height
);
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
auto
*
lr
=
learning_rate
.
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
in_rows
.
size
();
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
in_row_numel
;
j
++
)
{
out_data
[
in_rows
[
i
]
*
in_row_numel
+
j
]
-=
lr
[
0
]
*
in_data
[
i
*
in_row_numel
+
j
];
}
}
}
};
template
struct
SparseSGDFunctor
<
platform
::
CPUPlace
,
float
>;
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/operators/sgd_op.cu
浏览文件 @
7ad13fbf
...
@@ -14,6 +14,66 @@
...
@@ -14,6 +14,66 @@
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h"
#include "paddle/operators/sgd_op.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
{
template
<
typename
T
>
__global__
void
SparseSGDFunctorKernel
(
const
T
*
selected_rows
,
const
int64_t
*
rows
,
const
T
*
learning_rate
,
T
*
tensor_out
,
int64_t
row_numel
,
int
block_size
)
{
const
int
ty
=
blockIdx
.
y
;
int
tid
=
threadIdx
.
x
;
selected_rows
+=
ty
*
row_numel
;
tensor_out
+=
rows
[
ty
]
*
row_numel
;
for
(
int
index
=
tid
;
index
<
row_numel
;
index
+=
block_size
)
{
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle
::
platform
::
CudaAtomicAdd
(
tensor_out
+
index
,
-
1.0
*
learning_rate
[
0
]
*
selected_rows
[
index
]);
}
}
}
// namespace
template
<
typename
T
>
struct
SparseSGDFunctor
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
Tensor
&
learning_rate
,
framework
::
Tensor
*
output
)
{
auto
in_height
=
input
.
height
();
auto
out_dims
=
output
->
dims
();
PADDLE_ENFORCE_EQ
(
in_height
,
out_dims
[
0
]);
auto
&
in_value
=
input
.
value
();
auto
&
in_rows
=
input
.
rows
();
int64_t
in_row_numel
=
in_value
.
numel
()
/
in_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in_row_numel
,
output
->
numel
()
/
in_height
);
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in_rows
.
size
());
SparseSGDFunctorKernel
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
in_data
,
in_rows
.
data
(),
learning_rate
.
data
<
T
>
(),
out_data
,
in_row_numel
,
block_size
);
}
};
template
struct
SparseSGDFunctor
<
platform
::
GPUPlace
,
float
>;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
sgd
,
REGISTER_OP_GPU_KERNEL
(
sgd
,
...
...
paddle/operators/sgd_op.h
浏览文件 @
7ad13fbf
...
@@ -15,31 +15,53 @@ limitations under the License. */
...
@@ -15,31 +15,53 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/selected_rows.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
struct
SparseSGDFunctor
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
Tensor
&
learning_rate
,
framework
::
Tensor
*
output
);
};
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
SGDOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SGDOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
// Actually, all tensors are LoDTensor except SelectedRows.
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
p
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param
);
auto
p
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param
);
auto
g
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
grad
);
auto
g
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
grad
);
auto
o
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
o
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
lr
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
learning_rate
);
auto
lr
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
learning_rate
);
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
Eigen
::
DSizes
<
int
,
1
>
grad_dsize
(
grad
->
numel
());
Eigen
::
DSizes
<
int
,
1
>
grad_dsize
(
grad
->
numel
());
o
.
device
(
place
)
=
p
-
lr
.
broadcast
(
grad_dsize
)
*
g
;
o
.
device
(
place
)
=
p
-
lr
.
broadcast
(
grad_dsize
)
*
g
;
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ
(
param
,
param_out
);
auto
*
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
SparseSGDFunctor
<
Place
,
T
>
functor
;
functor
(
ctx
.
device_context
(),
*
grad
,
*
learning_rate
,
param_out
);
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
}
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/pybind/pybind.cc
浏览文件 @
7ad13fbf
...
@@ -154,7 +154,15 @@ PYBIND11_PLUGIN(core) {
...
@@ -154,7 +154,15 @@ PYBIND11_PLUGIN(core) {
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"set_height"
,
&
SelectedRows
::
set_height
)
.
def
(
"set_height"
,
&
SelectedRows
::
set_height
)
.
def
(
"height"
,
&
SelectedRows
::
height
)
.
def
(
"height"
,
&
SelectedRows
::
height
)
.
def
(
"set_rows"
,
&
SelectedRows
::
set_rows
)
.
def
(
"set_rows"
,
[](
SelectedRows
&
self
,
std
::
vector
<
int64_t
>
rows
)
{
#ifndef PADDLE_WITH_CUDA
self
.
set_rows
(
rows
);
#else
Vector
<
int64_t
>
new_rows
(
rows
);
self
.
set_rows
(
new_rows
);
#endif
})
.
def
(
"rows"
,
[](
SelectedRows
&
self
)
{
.
def
(
"rows"
,
[](
SelectedRows
&
self
)
{
#ifndef PADDLE_WITH_CUDA
#ifndef PADDLE_WITH_CUDA
return
self
.
rows
();
return
self
.
rows
();
...
@@ -187,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -187,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle.
return
self
.
GetMutable
<
LoDTensor
>
();
return
self
.
GetMutable
<
LoDTensor
>
();
},
},
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"get_selected_rows"
,
[](
Variable
&
self
)
->
SelectedRows
*
{
return
self
.
GetMutable
<
SelectedRows
>
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"get_net"
,
.
def
(
"get_net"
,
[](
Variable
&
self
)
->
operators
::
NetOp
*
{
[](
Variable
&
self
)
->
operators
::
NetOp
*
{
return
self
.
GetMutable
<
operators
::
NetOp
>
();
return
self
.
GetMutable
<
operators
::
NetOp
>
();
...
...
python/paddle/v2/framework/tests/test_selected_rows.py
浏览文件 @
7ad13fbf
...
@@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase):
...
@@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase):
place
=
core
.
CPUPlace
()
place
=
core
.
CPUPlace
()
height
=
10
height
=
10
rows
=
[
0
,
4
,
7
]
rows
=
[
0
,
4
,
7
]
row_numel
=
1
0
row_numel
=
1
2
sel
cted_rows
=
core
.
SelectedRows
(
rows
,
row_numel
)
sel
ected_rows
=
core
.
SelectedRows
(
rows
,
height
)
np_array
=
np
.
ones
((
len
(
rows
),
height
)).
astype
(
"float32"
)
np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
np_array
[
0
,
0
]
=
2.0
np_array
[
0
,
0
]
=
2.0
np_array
[
2
,
8
]
=
4.0
np_array
[
2
,
8
]
=
4.0
tensor
=
selcted_rows
.
get_tensor
()
tensor
=
sel
e
cted_rows
.
get_tensor
()
tensor
.
set
(
np_array
,
place
)
tensor
.
set
(
np_array
,
place
)
# compare rows
# compare rows
self
.
assertEqual
(
0
,
selcted_rows
.
rows
()[
0
])
self
.
assertEqual
(
0
,
sel
e
cted_rows
.
rows
()[
0
])
self
.
assertEqual
(
4
,
selcted_rows
.
rows
()[
1
])
self
.
assertEqual
(
4
,
sel
e
cted_rows
.
rows
()[
1
])
self
.
assertEqual
(
7
,
selcted_rows
.
rows
()[
2
])
self
.
assertEqual
(
7
,
sel
e
cted_rows
.
rows
()[
2
])
# compare height
# compare height
self
.
assertEqual
(
10
,
selcted_rows
.
height
())
self
.
assertEqual
(
10
,
sel
e
cted_rows
.
height
())
# compare tensor
# compare tensor
self
.
assertAlmostEqual
(
2.0
,
self
.
assertAlmostEqual
(
2.0
,
selcted_rows
.
get_tensor
().
get_float_element
(
0
))
sel
e
cted_rows
.
get_tensor
().
get_float_element
(
0
))
self
.
assertAlmostEqual
(
1.0
,
self
.
assertAlmostEqual
(
1.0
,
selcted_rows
.
get_tensor
().
get_float_element
(
1
))
sel
e
cted_rows
.
get_tensor
().
get_float_element
(
1
))
self
.
assertAlmostEqual
(
self
.
assertAlmostEqual
(
4.0
,
selcted_rows
.
get_tensor
().
get_float_element
(
2
*
row_numel
+
8
))
4.0
,
selected_rows
.
get_tensor
().
get_float_element
(
2
*
row_numel
+
8
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
python/paddle/v2/framework/tests/test_sgd_op.py
浏览文件 @
7ad13fbf
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
paddle.v2.framework.core
as
core
from
paddle.v2.framework.op
import
Operator
from
op_test
import
OpTest
from
op_test
import
OpTest
...
@@ -17,5 +19,70 @@ class TestSGDOp(OpTest):
...
@@ -17,5 +19,70 @@ class TestSGDOp(OpTest):
self
.
check_output
()
self
.
check_output
()
class
TestSparseSGDOp
(
unittest
.
TestCase
):
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
# create and initialize Grad Variable
height
=
10
rows
=
[
0
,
4
,
7
]
row_numel
=
12
grad_selected_rows
=
scope
.
var
(
'Grad'
).
get_selected_rows
()
grad_selected_rows
.
set_height
(
height
)
grad_selected_rows
.
set_rows
(
rows
)
np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
np_array
[
0
,
0
]
=
2.0
np_array
[
2
,
8
]
=
4.0
grad_tensor
=
grad_selected_rows
.
get_tensor
()
grad_tensor
.
set
(
np_array
,
place
)
# create and initialize Param Variable
param
=
scope
.
var
(
'Param'
).
get_tensor
()
param_array
=
np
.
full
((
height
,
row_numel
),
5.0
).
astype
(
"float32"
)
param
.
set
(
param_array
,
place
)
# create and initialize LeraningRate Variable
lr
=
scope
.
var
(
'LearningRate'
).
get_tensor
()
lr_array
=
np
.
full
((
1
),
2.0
).
astype
(
"float32"
)
lr
.
set
(
lr_array
,
place
)
# create and run sgd operator
sgd_op
=
Operator
(
"sgd"
,
Param
=
'Param'
,
Grad
=
'Grad'
,
ParamOut
=
'Param'
,
LearningRate
=
'LearningRate'
)
ctx
=
core
.
DeviceContext
.
create
(
place
)
sgd_op
.
run
(
scope
,
ctx
)
# get and compare result
result_array
=
np
.
array
(
param
)
# rows[0] = 0, 5.0 - 2.0 * 2.0
self
.
assertAlmostEqual
(
1.0
,
result_array
[
rows
[
0
],
0
])
# rows[0] = 0, 5.0 - 2.0 * 1.0
self
.
assertAlmostEqual
(
3.0
,
result_array
[
rows
[
0
],
2
])
# 5.0 - 2.0 * 0.0
self
.
assertAlmostEqual
(
5.0
,
result_array
[
1
,
0
])
# rows[1] = 4, 5.0 - 2.0 * 1.0
self
.
assertAlmostEqual
(
3.0
,
result_array
[
rows
[
1
],
10
])
# 5.0 - 2.0 * 0.0
self
.
assertAlmostEqual
(
5.0
,
result_array
[
5
,
8
])
# rows[2] = 7, 5.0 - 2.0 * 1.0
self
.
assertAlmostEqual
(
3.0
,
result_array
[
rows
[
2
],
1
])
# rows[2] = 7, 5.0 - 2.0 * 4.0
self
.
assertAlmostEqual
(
-
3.0
,
result_array
[
rows
[
2
],
8
])
def
test_sparse_sgd
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compile_gpu
():
places
.
append
(
core
.
GPUPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录