Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
02fda711
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02fda711
编写于
12月 23, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine sgd-op
上级
bb58a474
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
94 addition
and
83 deletion
+94
-83
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+1
-35
paddle/operators/sgd_op.cu
paddle/operators/sgd_op.cu
+69
-31
paddle/operators/sgd_op.h
paddle/operators/sgd_op.h
+24
-17
未找到文件。
paddle/operators/sgd_op.cc
浏览文件 @
02fda711
...
...
@@ -61,43 +61,9 @@ $$param\_out = param - learning\_rate * grad$$
}
};
template
<
typename
T
>
struct
SparseSGDFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
Tensor
&
learning_rate
,
framework
::
Tensor
*
output
)
{
auto
in_height
=
input
.
height
();
auto
out_dims
=
output
->
dims
();
PADDLE_ENFORCE_EQ
(
in_height
,
out_dims
[
0
]);
auto
&
in_value
=
input
.
value
();
auto
&
in_rows
=
input
.
rows
();
int64_t
in_row_numel
=
in_value
.
numel
()
/
in_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in_row_numel
,
output
->
numel
()
/
in_height
);
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
auto
*
lr
=
learning_rate
.
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
in_rows
.
size
();
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
in_row_numel
;
j
++
)
{
out_data
[
in_rows
[
i
]
*
in_row_numel
+
j
]
-=
lr
[
0
]
*
in_data
[
i
*
in_row_numel
+
j
];
}
}
}
};
template
struct
SparseSGDFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
SparseSGDFunctor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
sgd
,
ops
::
SGDOp
,
ops
::
SGDOpMaker
);
REGISTER_OP_CPU_KERNEL
(
sgd
,
ops
::
SGDOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SGDOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
sgd
,
ops
::
SGDOpKernel
<
float
>
,
ops
::
SGDOpKernel
<
double
>
);
paddle/operators/sgd_op.cu
浏览文件 @
02fda711
...
...
@@ -20,6 +20,19 @@ namespace paddle {
namespace
operators
{
namespace
{
template
<
typename
T
>
__global__
void
SGDKernel
(
const
T
*
g
,
const
T
*
p
,
const
T
*
learning_rate
,
const
int
num
,
T
*
p_out
)
{
T
lr
=
learning_rate
[
0
];
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
num
;
i
+=
grid_size
)
{
T
g_data
=
g
[
i
];
T
p_data
=
p
[
i
];
p_out
[
i
]
=
p_data
-
lr
*
g_data
;
}
}
template
<
typename
T
,
int
block_size
>
__global__
void
SparseSGDFunctorKernel
(
const
T
*
selected_rows
,
const
int64_t
*
rows
,
...
...
@@ -41,40 +54,65 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
}
// namespace
template
<
typename
T
>
struct
SparseSGDFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
Tensor
&
learning_rate
,
framework
::
Tensor
*
output
)
{
auto
in_height
=
input
.
height
();
auto
out_dims
=
output
->
dims
();
PADDLE_ENFORCE_EQ
(
in_height
,
out_dims
[
0
]);
auto
&
in_value
=
input
.
value
();
auto
&
in_rows
=
input
.
rows
();
int64_t
in_row_numel
=
in_value
.
numel
()
/
in_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in_row_numel
,
output
->
numel
()
/
in_height
);
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in_rows
.
size
());
SparseSGDFunctorKernel
<
T
,
256
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in_data
,
in_rows
.
data
(),
learning_rate
.
data
<
T
>
(),
out_data
,
in_row_numel
);
class
SGDOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
// Actually, all tensors are LoDTensor except SelectedRows.
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
*
grad_data
=
grad
->
data
<
T
>
();
auto
*
param_data
=
param
->
data
<
T
>
();
auto
*
param_out_data
=
param_out
->
data
<
T
>
();
int
block
=
512
;
int
grid
=
(
param
->
numel
()
+
block
-
1
)
/
block
;
SGDKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
grad_data
,
param_data
,
learning_rate
->
data
<
T
>
(),
param
->
numel
(),
param_out_data
);
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ
(
param
,
param_out
);
auto
*
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
auto
in_height
=
grad
->
height
();
auto
out_dims
=
param_out
->
dims
();
PADDLE_ENFORCE_EQ
(
in_height
,
out_dims
[
0
]);
auto
&
in_value
=
grad
->
value
();
auto
&
in_rows
=
grad
->
rows
();
int64_t
in_row_numel
=
in_value
.
numel
()
/
in_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in_row_numel
,
param_out
->
numel
()
/
in_height
);
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
param_out
->
data
<
T
>
();
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in_rows
.
size
());
SparseSGDFunctorKernel
<
T
,
256
><<<
grid
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
in_data
,
in_rows
.
data
(),
learning_rate
->
data
<
T
>
(),
out_data
,
in_row_numel
);
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
}
}
};
template
struct
SparseSGDFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SparseSGDFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
sgd
,
ops
::
SGDOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SGDOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
sgd
,
ops
::
SGDOpCUDAKernel
<
float
>
,
ops
::
SGDOpCUDAKernel
<
double
>
);
paddle/operators/sgd_op.h
浏览文件 @
02fda711
...
...
@@ -20,15 +20,7 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
struct
SparseSGDFunctor
{
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
const
framework
::
Tensor
&
learning_rate
,
framework
::
Tensor
*
output
);
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
>
class
SGDOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -45,21 +37,36 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto
p
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param
);
auto
g
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
grad
);
auto
o
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
lr
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
learning_rate
);
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
*
lr
=
learning_rate
->
data
<
T
>
();
Eigen
::
DSizes
<
int
,
1
>
grad_dsize
(
grad
->
numel
());
o
.
device
(
place
)
=
p
-
lr
.
broadcast
(
grad_dsize
)
*
g
;
o
=
p
-
lr
[
0
]
*
g
;
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ
(
param
,
param_out
);
auto
*
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
SparseSGDFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
ctx
.
template
device_context
<
DeviceContext
>(),
*
grad
,
*
learning_rate
,
param_out
);
auto
in_height
=
grad
->
height
();
auto
out_dims
=
param_out
->
dims
();
PADDLE_ENFORCE_EQ
(
in_height
,
out_dims
[
0
]);
auto
&
in_value
=
grad
->
value
();
auto
&
in_rows
=
grad
->
rows
();
int64_t
in_row_numel
=
in_value
.
numel
()
/
in_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in_row_numel
,
param_out
->
numel
()
/
in_height
);
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
param_out
->
data
<
T
>
();
auto
*
lr
=
learning_rate
->
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
in_rows
.
size
();
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
in_row_numel
;
j
++
)
{
out_data
[
in_rows
[
i
]
*
in_row_numel
+
j
]
-=
lr
[
0
]
*
in_data
[
i
*
in_row_numel
+
j
];
}
}
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录