Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
584c3f04
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
584c3f04
编写于
9月 29, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix sparse rmsprop
上级
23644940
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
276 addition
and
57 deletion
+276
-57
paddle/fluid/operators/adam_op.h
paddle/fluid/operators/adam_op.h
+3
-16
paddle/fluid/operators/math/algorithm.h
paddle/fluid/operators/math/algorithm.h
+44
-0
paddle/fluid/operators/rmsprop_op.h
paddle/fluid/operators/rmsprop_op.h
+229
-41
未找到文件。
paddle/fluid/operators/adam_op.h
浏览文件 @
584c3f04
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
...
@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
...
@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
row_numel_
(
row_numel
),
row_numel_
(
row_numel
),
row_count_
(
row_count
)
{}
row_count_
(
row_count
)
{}
inline
HOSTDEVICE
int64_t
BinarySearchInRows
(
int64_t
row
)
const
{
int64_t
beg
=
0
,
end
=
row_count_
-
1
;
while
(
beg
<=
end
)
{
auto
mid
=
((
beg
+
end
)
>>
1
);
if
(
rows_
[
mid
]
==
row
)
return
mid
;
else
if
(
rows_
[
mid
]
<
row
)
beg
=
mid
+
1
;
else
end
=
mid
-
1
;
}
return
-
1
;
}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
int64_t
row
=
i
/
row_numel_
;
auto
row_idx
=
auto
row_idx
=
BinarySearchInRows
(
row
);
math
::
BinarySearch
<
int64_t
>
(
rows_
,
row_count_
,
i
/
row_numel_
);
T
g
=
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
T
g
=
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
// The following code is the same as dense
// The following code is the same as dense
...
...
paddle/fluid/operators/math/algorithm.h
0 → 100644
浏览文件 @
584c3f04
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstdint> // for int64_t
#include <numeric>
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
HOSTDEVICE
inline
int64_t
BinarySearch
(
const
T
*
x
,
int64_t
num
,
const
T
&
val
)
{
int64_t
beg
=
0
,
end
=
num
-
1
;
while
(
beg
<=
end
)
{
auto
mid
=
((
beg
+
end
)
>>
1
);
if
(
x
[
mid
]
==
val
)
return
mid
;
else
if
(
x
[
mid
]
<
val
)
beg
=
mid
+
1
;
else
end
=
mid
-
1
;
}
return
-
1
;
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/rmsprop_op.h
浏览文件 @
584c3f04
...
@@ -13,66 +13,254 @@ See the License for the specific language governing permissions and
...
@@ -13,66 +13,254 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <math.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
>
struct
DenseRmspropGradFunctor
{
inline
explicit
DenseRmspropGradFunctor
(
const
T
*
grad
)
:
grad_
(
grad
)
{}
HOSTDEVICE
inline
T
operator
()(
int64_t
idx
)
const
{
return
grad_
[
idx
];
}
const
T
*
grad_
;
};
template
<
typename
T
>
struct
SparseRmspropGradFunctor
{
inline
SparseRmspropGradFunctor
(
const
T
*
grad
,
const
int64_t
*
rows
,
int64_t
row_numel
,
int64_t
row_count
)
:
grad_
(
grad
),
rows_
(
rows
),
row_numel_
(
row_numel
),
row_count_
(
row_count
)
{}
HOSTDEVICE
inline
T
operator
()(
int64_t
idx
)
const
{
auto
row_idx
=
math
::
BinarySearch
(
rows_
,
row_count_
,
idx
/
row_numel_
);
return
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
idx
%
row_numel_
]
:
0
;
}
const
T
*
grad_
;
const
int64_t
*
rows_
;
int64_t
row_numel_
;
int64_t
row_count_
;
};
template
<
typename
T
,
typename
GradFunctor
>
struct
UncenteredRmspropFunctor
{
UncenteredRmspropFunctor
(
T
*
param
,
T
*
ms
,
T
*
mom
,
const
T
*
lr
,
T
rho
,
T
epsilon
,
T
momentum
,
const
GradFunctor
&
grad_functor
)
:
param_
(
param
),
ms_
(
ms
),
mom_
(
mom
),
lr_
(
lr
),
rho_
(
rho
),
epsilon_
(
epsilon
),
momentum_
(
momentum
),
grad_functor_
(
grad_functor
)
{}
HOSTDEVICE
inline
void
operator
()(
int64_t
idx
)
const
{
T
g
=
grad_functor_
(
idx
);
T
ms_out
=
rho_
*
ms_
[
idx
]
+
(
1
-
rho_
)
*
g
*
g
;
T
mom_out
=
momentum_
*
mom_
[
idx
]
+
lr_
[
0
]
*
g
/
sqrt
(
ms_out
+
epsilon_
);
param_
[
idx
]
-=
mom_out
;
ms_
[
idx
]
=
ms_out
;
mom_
[
idx
]
=
mom_out
;
}
T
*
param_
;
T
*
ms_
;
T
*
mom_
;
const
T
*
lr_
;
T
rho_
;
T
epsilon_
;
T
momentum_
;
GradFunctor
grad_functor_
;
};
template
<
typename
T
,
typename
GradFunctor
>
struct
CenteredRmspropFunctor
{
CenteredRmspropFunctor
(
T
*
param
,
T
*
ms
,
T
*
mom
,
T
*
mean_grad
,
const
T
*
lr
,
T
rho
,
T
epsilon
,
T
momentum
,
const
GradFunctor
&
grad_functor
)
:
param_
(
param
),
ms_
(
ms
),
mom_
(
mom
),
mean_grad_
(
mean_grad
),
lr_
(
lr
),
rho_
(
rho
),
epsilon_
(
epsilon
),
momentum_
(
momentum
),
grad_functor_
(
grad_functor
)
{}
HOSTDEVICE
inline
void
operator
()(
int64_t
idx
)
const
{
T
g
=
grad_functor_
(
idx
);
T
ms_out
=
rho_
*
ms_
[
idx
]
+
(
1
-
rho_
)
*
g
*
g
;
T
mg_out
=
rho_
*
mean_grad_
[
idx
]
+
(
1
-
rho_
)
*
g
;
T
mom_out
=
momentum_
*
mom_
[
idx
]
+
lr_
[
0
]
*
g
/
sqrt
(
ms_out
-
mg_out
*
mg_out
+
epsilon_
);
param_
[
idx
]
-=
mom_out
;
ms_
[
idx
]
=
ms_out
;
mom_
[
idx
]
=
mom_out
;
mean_grad_
[
idx
]
=
mg_out
;
}
T
*
param_
;
T
*
ms_
;
T
*
mom_
;
T
*
mean_grad_
;
const
T
*
lr_
;
T
rho_
;
T
epsilon_
;
T
momentum_
;
GradFunctor
grad_functor_
;
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
RmspropOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
RmspropOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
param_out
=
ctx
.
Output
<
Tensor
>
(
"ParamOut"
);
using
Tensor
=
framework
::
LoDTensor
;
auto
*
moment_out
=
ctx
.
Output
<
Tensor
>
(
"MomentOut"
);
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
auto
*
mean_square_out
=
ctx
.
Output
<
Tensor
>
(
"MeanSquareOut"
);
auto
*
param_out
=
ctx
.
Output
<
Tensor
>
(
"ParamOut"
);
auto
*
moment_out
=
ctx
.
Output
<
Tensor
>
(
"MomentOut"
);
auto
*
mean_square_out
=
ctx
.
Output
<
Tensor
>
(
"MeanSquareOut"
);
auto
grad
=
ctx
.
Input
<
Tensor
>
(
"Grad"
);
auto
epsilon
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
auto
rho
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"decay"
));
auto
momentum
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
bool
centered
=
ctx
.
Attr
<
bool
>
(
"centered"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
p_tensor
=
*
ctx
.
Input
<
Tensor
>
(
"Param"
);
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
ms_tensor
=
*
ctx
.
Input
<
Tensor
>
(
"MeanSquare"
);
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
lr_tensor
=
*
ctx
.
Input
<
Tensor
>
(
"LearningRate"
);
auto
&
mom_tensor
=
*
ctx
.
Input
<
Tensor
>
(
"Moment"
);
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
PADDLE_ENFORCE_EQ
(
&
p_tensor
,
param_out
,
float
rho
=
ctx
.
Attr
<
float
>
(
"decay"
);
"Param and ParamOut must be the same Tensor"
);
float
momentum
=
ctx
.
Attr
<
float
>
(
"momentum"
);
PADDLE_ENFORCE_EQ
(
&
mom_tensor
,
moment_out
,
bool
centered
=
ctx
.
Attr
<
bool
>
(
"centered"
);
"Moment and MomentOut must be the same Tensor"
);
PADDLE_ENFORCE_EQ
(
&
ms_tensor
,
mean_square_out
,
"MeanSquare and MeanSquareOut must be the same Tensor"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
size_t
limit
=
static_cast
<
size_t
>
(
ms_tensor
.
numel
());
if
(
grad_var
->
IsType
<
Tensor
>
())
{
auto
&
grad_tensor
=
grad_var
->
Get
<
Tensor
>
();
if
(
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
)
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
lr_value
=
lr_tensor
.
data
<
T
>
()[
0
];
auto
p
=
EigenVector
<
T
>::
Flatten
(
p_tensor
);
auto
ms
=
EigenVector
<
T
>::
Flatten
(
ms_tensor
);
auto
g
=
EigenVector
<
T
>::
Flatten
(
grad_tensor
);
auto
mom
=
EigenVector
<
T
>::
Flatten
(
mom_tensor
);
auto
p_out
=
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
mom_out
=
EigenVector
<
T
>::
Flatten
(
*
moment_out
);
auto
ms_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_square_out
);
ms_out
.
device
(
place
)
=
rho
*
ms
+
(
1
-
rho
)
*
g
*
g
;
if
(
centered
)
{
auto
&
mg_tensor
=
*
ctx
.
Input
<
Tensor
>
(
"MeanGrad"
);
auto
mg
=
EigenVector
<
T
>::
Flatten
(
mg_tensor
);
auto
*
mean_grad_out
=
ctx
.
Output
<
Tensor
>
(
"MeanGradOut"
);
PADDLE_ENFORCE
(
&
mg_tensor
,
mean_grad_out
,
"MeanGrad and MeanGradOut must be the same Tensor"
);
auto
mg_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_grad_out
);
mg_out
.
device
(
place
)
=
rho
*
mg
+
(
1
-
rho
)
*
g
;
mom_out
.
device
(
place
)
=
momentum
*
mom
+
lr_value
*
g
/
(
ms_out
-
mg_out
.
square
()
+
epsilon
).
sqrt
();
}
else
{
mom_out
.
device
(
place
)
=
momentum
*
mom
+
lr_value
*
g
/
(
ms_out
+
epsilon
).
sqrt
();
}
p_out
.
device
(
place
)
=
p
-
mom_out
;
}
else
{
DenseRmspropGradFunctor
<
T
>
grad_func
(
grad_tensor
.
data
<
T
>
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
limit
);
if
(
centered
)
{
auto
&
mg_tensor
=
*
ctx
.
Input
<
Tensor
>
(
"MeanGrad"
);
auto
*
mean_grad_out
=
ctx
.
Output
<
Tensor
>
(
"MeanGradOut"
);
PADDLE_ENFORCE
(
&
mg_tensor
,
mean_grad_out
,
"MeanGrad and MeanGradOut must be the same Tensor"
);
for_range
(
CenteredRmspropFunctor
<
T
,
DenseRmspropGradFunctor
<
T
>>
(
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_grad_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
rho
,
epsilon
,
momentum
,
grad_func
));
}
else
{
for_range
(
UncenteredRmspropFunctor
<
T
,
DenseRmspropGradFunctor
<
T
>>
(
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
rho
,
epsilon
,
momentum
,
grad_func
));
}
}
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
grad
=
grad_var
->
Get
<
framework
::
SelectedRows
>
();
auto
*
merged_grad
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
()
->
GetMutable
<
framework
::
SelectedRows
>
();
math
::
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
merge_func
(
dev_ctx
,
grad
,
merged_grad
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
limit
);
const
int64_t
*
rows
;
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
rows
=
merged_grad
->
rows
().
CUDAData
(
ctx
.
GetPlace
());
}
else
{
#endif
rows
=
merged_grad
->
rows
().
data
();
#ifdef PADDLE_WITH_CUDA
}
#endif
auto
&
merged_tensor
=
merged_grad
->
value
();
int64_t
row_count
=
merged_grad
->
rows
().
size
();
int64_t
row_numel
=
merged_tensor
.
numel
()
/
row_count
;
SparseRmspropGradFunctor
<
T
>
grad_func
(
merged_tensor
.
data
<
T
>
(),
rows
,
row_numel
,
row_count
);
auto
p
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"Param"
));
if
(
centered
)
{
auto
ms
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"MeanSquare"
));
auto
&
mg_tensor
=
*
ctx
.
Input
<
Tensor
>
(
"MeanGrad"
);
auto
lr
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"LearningRate"
));
auto
*
mean_grad_out
=
ctx
.
Output
<
Tensor
>
(
"MeanGradOut"
);
auto
g
=
EigenVector
<
T
>::
Flatten
(
*
grad
);
PADDLE_ENFORCE
(
&
mg_tensor
,
mean_grad_out
,
auto
mom
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"Moment"
));
"MeanGrad and MeanGradOut must be the same Tensor"
);
for_range
(
CenteredRmspropFunctor
<
T
,
SparseRmspropGradFunctor
<
T
>>
(
auto
p_out
=
EigenVector
<
T
>::
Flatten
(
*
param_out
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
auto
mom_out
=
EigenVector
<
T
>::
Flatten
(
*
moment_out
);
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
auto
ms_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_square_out
);
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
mean_grad_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
rho
,
epsilon
,
momentum
,
grad_func
));
Eigen
::
DSizes
<
int
,
1
>
grad_dsize
(
static_cast
<
int
>
(
grad
->
numel
()));
}
else
{
for_range
(
UncenteredRmspropFunctor
<
T
,
SparseRmspropGradFunctor
<
T
>>
(
ms_out
.
device
(
place
)
=
rho
*
ms
+
(
1
-
rho
)
*
g
*
g
;
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
if
(
centered
)
{
mean_square_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
auto
mg
=
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
Tensor
>
(
"MeanGrad"
));
moment_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr_tensor
.
data
<
T
>
(),
auto
*
mean_grad_out
=
ctx
.
Output
<
Tensor
>
(
"MeanGradOut"
);
rho
,
epsilon
,
momentum
,
grad_func
));
mean_grad_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
auto
mg_out
=
EigenVector
<
T
>::
Flatten
(
*
mean_grad_out
);
mg_out
.
device
(
place
)
=
rho
*
mg
+
(
1
-
rho
)
*
g
;
mom_out
.
device
(
place
)
=
momentum
*
mom
+
lr
.
broadcast
(
grad_dsize
)
*
g
/
(
ms_out
-
mg_out
.
square
()
+
epsilon
).
sqrt
();
}
else
{
}
else
{
mom_out
.
device
(
place
)
=
PADDLE_THROW
(
"RMSProp only supports LoDTensor or SelectedRows gradient"
);
momentum
*
mom
+
lr
.
broadcast
(
grad_dsize
)
*
g
/
(
ms_out
+
epsilon
).
sqrt
();
}
}
p_out
.
device
(
place
)
=
p
-
mom_out
;
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录