Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b4a9c184
P
Paddle
项目概览
PaddlePaddle
/
Paddle
8 个月 前同步成功
通知
2282
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b4a9c184
编写于
10月 18, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
差异文件
test=release/1.0.0
上级
a4f5cc03
59577786
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
731 addition
and
138 deletion
+731
-138
paddle/fluid/operators/adadelta_op.cc
paddle/fluid/operators/adadelta_op.cc
+12
-0
paddle/fluid/operators/adadelta_op.h
paddle/fluid/operators/adadelta_op.h
+11
-0
paddle/fluid/operators/adagrad_op.h
paddle/fluid/operators/adagrad_op.h
+20
-13
paddle/fluid/operators/adam_op.h
paddle/fluid/operators/adam_op.h
+6
-0
paddle/fluid/operators/adamax_op.cc
paddle/fluid/operators/adamax_op.cc
+10
-0
paddle/fluid/operators/adamax_op.h
paddle/fluid/operators/adamax_op.h
+11
-0
paddle/fluid/operators/decayed_adagrad_op.cc
paddle/fluid/operators/decayed_adagrad_op.cc
+10
-0
paddle/fluid/operators/decayed_adagrad_op.h
paddle/fluid/operators/decayed_adagrad_op.h
+11
-0
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+8
-1
paddle/fluid/operators/ftrl_op.cc
paddle/fluid/operators/ftrl_op.cc
+10
-0
paddle/fluid/operators/ftrl_op.h
paddle/fluid/operators/ftrl_op.h
+11
-0
paddle/fluid/operators/momentum_op.cc
paddle/fluid/operators/momentum_op.cc
+45
-13
paddle/fluid/operators/momentum_op.cu
paddle/fluid/operators/momentum_op.cu
+3
-61
paddle/fluid/operators/momentum_op.h
paddle/fluid/operators/momentum_op.h
+311
-14
paddle/fluid/operators/rmsprop_op.cc
paddle/fluid/operators/rmsprop_op.cc
+5
-0
paddle/fluid/operators/rmsprop_op.h
paddle/fluid/operators/rmsprop_op.h
+5
-0
paddle/fluid/operators/sgd_op.cc
paddle/fluid/operators/sgd_op.cc
+16
-13
paddle/fluid/operators/sgd_op.cu
paddle/fluid/operators/sgd_op.cu
+6
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+6
-2
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+15
-2
python/paddle/fluid/tests/unittests/dist_simnet_bow.py
python/paddle/fluid/tests/unittests/dist_simnet_bow.py
+17
-5
python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
+74
-4
python/paddle/fluid/tests/unittests/test_momentum_op.py
python/paddle/fluid/tests/unittests/test_momentum_op.py
+94
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+14
-10
未找到文件。
paddle/fluid/operators/adadelta_op.cc
浏览文件 @
b4a9c184
...
...
@@ -18,6 +18,7 @@ namespace paddle {
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
AdadeltaOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel {
"Input(AvgSquaredGrad) of AdadeltaOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"AvgSquaredUpdate"
),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of AdadeltaOp should not be null."
);
...
...
@@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"AvgSquaredGradOut"
,
param_dim
);
ctx
->
SetOutputDim
(
"AvgSquaredUpdateOut"
,
param_dim
);
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
...
...
paddle/fluid/operators/adadelta_op.h
浏览文件 @
b4a9c184
...
...
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class
AdadeltaOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
avg_squared_grad_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"AvgSquaredGradOut"
);
...
...
paddle/fluid/operators/adagrad_op.h
浏览文件 @
b4a9c184
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -21,25 +22,31 @@ namespace operators {
template
<
typename
DeviceContext
,
typename
T
>
struct
SparseAdagradFunctor
{
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
grad
,
const
framework
::
Tensor
&
learning_rate
,
T
epsilon
,
framework
::
Tensor
*
moment
,
framework
::
Tensor
*
param
);
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
grad
,
const
framework
::
Tensor
&
learning_rate
,
T
epsilon
,
framework
::
Tensor
*
moment
,
framework
::
Tensor
*
param
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
AdagradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
auto
*
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
param_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
moment_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
epsilon
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
param
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
));
...
...
@@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
));
auto
moment
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
));
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
param_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out_tensor
);
auto
moment_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
moment_out_tensor
);
auto
*
place
=
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
*
place
=
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
moment_out
.
device
(
*
place
)
=
moment
+
grad
*
grad
;
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
moment_out_tensor
->
numel
());
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
auto
*
lr
=
learning_rate
->
data
<
T
>
();
auto
*
lr
=
learning_rate
->
data
<
T
>
();
param_out
.
device
(
*
place
)
=
param
-
lr
[
0
]
*
grad
/
(
moment_out
.
sqrt
()
+
epsilon
);
}
else
{
...
...
@@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
lr
.
broadcast
(
m_dsize
)
*
grad
/
(
moment_out
.
sqrt
()
+
epsilon
);
}
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
param_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_tensor
,
param_out_tensor
);
auto
*
moment_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
);
auto
*
moment_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
);
PADDLE_ENFORCE_EQ
(
moment_tensor
,
moment_out_tensor
);
SparseAdagradFunctor
<
DeviceContext
,
T
>
functor
;
...
...
paddle/fluid/operators/adam_op.h
浏览文件 @
b4a9c184
...
...
@@ -231,6 +231,12 @@ template <typename DeviceContext, typename T>
class
AdamOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
using
paddle
::
framework
::
LoDTensor
;
using
paddle
::
operators
::
detail
::
Ref
;
...
...
paddle/fluid/operators/adamax_op.cc
浏览文件 @
b4a9c184
...
...
@@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel {
"Input(LearningRate) of AdamaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Beta1Pow"
),
"Input(Beta1Pow) of AdamaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of AdamaxOp should not be null."
);
...
...
paddle/fluid/operators/adamax_op.h
浏览文件 @
b4a9c184
...
...
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class
AdamaxOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
auto
inf_norm_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"InfNormOut"
);
...
...
paddle/fluid/operators/decayed_adagrad_op.cc
浏览文件 @
b4a9c184
...
...
@@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LearningRate"
),
"Input(LearningRate) of DecayedAdagradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of DecayedAdagradOp should not be null."
);
...
...
paddle/fluid/operators/decayed_adagrad_op.h
浏览文件 @
b4a9c184
...
...
@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class
DecayedAdagradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
b4a9c184
...
...
@@ -70,6 +70,12 @@ class FillConstantOp : public framework::OperatorBase {
}
};
class
FillConstantOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{}
};
class
FillConstantOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -102,4 +108,5 @@ Fill up a variable with specified constant value.
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fill_constant
,
ops
::
FillConstantOp
,
ops
::
FillConstantInferShape
,
ops
::
FillConstantOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
FillConstantOpVarTypeInference
);
paddle/fluid/operators/ftrl_op.cc
浏览文件 @
b4a9c184
...
...
@@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel {
"Input(Grad) of FTRL should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LearningRate"
),
"Input(LearningRate) of FTRL should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Grad"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Grad"
).
front
(),
ctx
->
GetInputsVarType
(
"Grad"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of FTRL should not be null."
);
...
...
paddle/fluid/operators/ftrl_op.h
浏览文件 @
b4a9c184
...
...
@@ -28,6 +28,17 @@ template <typename DeviceContext, typename T>
class
FTRLOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Grad"
).
front
(),
grad_var
->
Type
().
name
());
auto
*
param_out
=
ctx
.
Output
<
Tensor
>
(
"ParamOut"
);
auto
*
sq_accum_out
=
ctx
.
Output
<
Tensor
>
(
"SquaredAccumOut"
);
auto
*
lin_accum_out
=
ctx
.
Output
<
Tensor
>
(
"LinearAccumOut"
);
...
...
paddle/fluid/operators/momentum_op.cc
浏览文件 @
b4a9c184
...
...
@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(param) of Momentum should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
@@ -33,6 +33,11 @@ class MomentumOp : public framework::OperatorWithKernel {
"Input(velocity) of Momentum should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LearningRate"
),
"Input(LearningRate) of Momentum should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(ParamOut) of Momentum should not be null."
);
...
...
@@ -40,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel {
"Output(VelocityOut) of Momentum should not be null."
);
auto
param_dim
=
ctx
->
GetInputDim
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Grad"
),
"Param and Grad input of MomentumOp should have the same dimension."
);
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Velocity"
),
"Param and Velocity of MomentumOp should have the same dimension."
);
if
(
ctx
->
GetInputsVarType
(
"Grad"
)[
0
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Grad"
),
"Param and Grad input of MomentumOp should have the same dimension."
);
PADDLE_ENFORCE_EQ
(
param_dim
,
ctx
->
GetInputDim
(
"Velocity"
),
"Param and Velocity of MomentumOp should have the same dimension."
);
}
PADDLE_ENFORCE_EQ
(
framework
::
product
(
ctx
->
GetInputDim
(
"LearningRate"
)),
1
,
"Learning_rate should be a scalar"
);
...
...
@@ -53,13 +61,34 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"VelocityOut"
,
param_dim
);
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"Param"
));
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
class
MomentumOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var
=
op_desc
.
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
op_desc
.
Output
(
"ParamOut"
))
{
if
(
block
->
FindRecursiveOrCreateVar
(
input_var
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
if
(
block
->
FindRecursiveOrCreateVar
(
input_var
).
GetType
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
else
{
PADDLE_THROW
(
"Only support LodTensor and SelectedRows, Unexpected Input Type."
);
}
}
}
};
class
MomentumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -110,6 +139,9 @@ $$
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
momentum
,
ops
::
MomentumOp
,
ops
::
MomentumOpMaker
);
REGISTER_OP_CPU_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
float
>
,
ops
::
MomentumOpKernel
<
double
>
);
REGISTER_OPERATOR
(
momentum
,
ops
::
MomentumOp
,
ops
::
MomentumOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
MomentumOpInferVarType
);
REGISTER_OP_CPU_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/momentum_op.cu
浏览文件 @
b4a9c184
...
...
@@ -15,65 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/momentum_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__global__
void
MomentumKernel
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
learning_rate
,
const
T
mu
,
const
int64_t
num
,
bool
use_nesterov
,
T
*
p_out
,
T
*
v_out
)
{
T
lr
=
learning_rate
[
0
];
if
(
use_nesterov
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
g_val
=
g
[
i
];
T
v_new
=
v
[
i
]
*
mu
+
g_val
;
v_out
[
i
]
=
v_new
;
p_out
[
i
]
=
p
[
i
]
-
(
g_val
+
v_new
*
mu
)
*
lr
;
}
}
else
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
v_new
=
v
[
i
]
*
mu
+
g
[
i
];
v_out
[
i
]
=
v_new
;
p_out
[
i
]
=
p
[
i
]
-
lr
*
v_new
;
}
}
}
template
<
typename
T
>
class
MomentumOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
T
*
p_out
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
v_out
=
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
*
p
=
param
->
data
<
T
>
();
auto
*
v
=
velocity
->
data
<
T
>
();
auto
*
g
=
grad
->
data
<
T
>
();
auto
*
lr
=
learning_rate
->
data
<
T
>
();
int
block
=
512
;
int
grid
=
(
param
->
numel
()
+
block
-
1
)
/
block
;
MomentumKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
p
,
g
,
v
,
lr
,
mu
,
param
->
numel
(),
use_nesterov
,
p_out
,
v_out
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
momentum
,
ops
::
MomentumOpCUDAKernel
<
float
>
,
ops
::
MomentumOpCUDAKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/momentum_op.h
浏览文件 @
b4a9c184
...
...
@@ -13,29 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
MomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
using
framework
::
Tensor
;
using
framework
::
SelectedRows
;
struct
NoNesterov
;
struct
UseNesterov
;
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
template
<
typename
T
>
class
CPUDenseMomentumFunctor
{
private:
const
Tensor
*
param
;
const
Tensor
*
grad
;
const
Tensor
*
velocity
;
const
Tensor
*
learning_rate
;
const
T
mu
;
const
T
use_nesterov
;
Tensor
*
param_out
;
Tensor
*
velocity_out
;
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
public:
CPUDenseMomentumFunctor
(
const
Tensor
*
param
,
const
Tensor
*
grad
,
const
Tensor
*
velocity
,
const
Tensor
*
learning_rate
,
const
T
mu
,
const
bool
use_nesterov
,
Tensor
*
param_out
,
Tensor
*
velocity_out
)
:
param
(
param
),
grad
(
grad
),
velocity
(
velocity
),
learning_rate
(
learning_rate
),
mu
(
mu
),
use_nesterov
(
use_nesterov
),
param_out
(
param_out
),
velocity_out
(
velocity_out
)
{}
inline
void
operator
()()
{
auto
p_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
v_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out
);
...
...
@@ -53,5 +72,283 @@ class MomentumOpKernel : public framework::OpKernel<T> {
}
};
template
<
typename
T
,
typename
UpdateMethod
>
class
DenseMomentumFunctor
;
// NOTE(dzh) for performance.
// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
// functor.
template
<
typename
T
>
class
DenseMomentumFunctor
<
T
,
UseNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
num_
;
T
*
p_out_
;
T
*
v_out_
;
public:
DenseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
learning_rate
,
const
T
mu
,
const
int64_t
num
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
lr_
(
learning_rate
),
mu_
(
mu
),
num_
(
num
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// put memory access in register
const
T
p
=
p_
[
i
];
const
T
g
=
g_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
(
g
+
v_out
*
mu_
)
*
lr
;
// write reigster to memory
v_out_
[
i
]
=
v_out
;
p_out_
[
i
]
=
p_out
;
}
};
template
<
typename
T
>
class
DenseMomentumFunctor
<
T
,
NoNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
num_
;
T
*
p_out_
;
T
*
v_out_
;
public:
DenseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
learning_rate
,
const
T
mu
,
const
int64_t
num
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
lr_
(
learning_rate
),
mu_
(
mu
),
num_
(
num
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// put memory access in register
const
T
p
=
p_
[
i
];
const
T
g
=
g_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
lr
*
v_out
;
// write reigster to memory
v_out_
[
i
]
=
v_out
;
p_out_
[
i
]
=
p_out
;
}
};
template
<
typename
T
,
typename
UpdateMethod
>
class
SparseMomentumFunctor
;
template
<
typename
T
>
class
SparseMomentumFunctor
<
T
,
UseNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
*
rows_
;
const
int64_t
row_numel_
;
const
int64_t
row_height_
;
T
*
p_out_
;
T
*
v_out_
;
public:
SparseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
lr
,
const
T
mu
,
const
int64_t
*
rows
,
int64_t
row_numel
,
int64_t
row_height
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
lr_
(
lr
),
mu_
(
mu
),
rows_
(
rows
),
row_numel_
(
row_numel
),
row_height_
(
row_height
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
{
auto
row_idx
=
math
::
BinarySearch
<
int64_t
>
(
rows_
,
row_height_
,
i
/
row_numel_
);
T
g
=
row_idx
>=
0
?
g_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
// put memory access in register
const
T
p
=
p_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
(
g
+
v_out
*
mu_
)
*
lr
;
// write reigster to memory
v_out_
[
i
]
=
v_out
;
p_out_
[
i
]
=
p_out
;
}
};
template
<
typename
T
>
class
SparseMomentumFunctor
<
T
,
NoNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
*
rows_
;
const
int64_t
row_numel_
;
const
int64_t
row_height_
;
T
*
p_out_
;
T
*
v_out_
;
public:
SparseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
lr
,
const
T
mu
,
const
int64_t
*
rows
,
int64_t
row_numel
,
int64_t
row_height
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
lr_
(
lr
),
mu_
(
mu
),
rows_
(
rows
),
row_numel_
(
row_numel
),
row_height_
(
row_height
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
{
auto
row_idx
=
math
::
BinarySearch
<
int64_t
>
(
rows_
,
row_height_
,
i
/
row_numel_
);
T
g
=
row_idx
>=
0
?
g_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
// put memory access in register
const
T
p
=
p_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
v_out
*
lr
;
// write reigster to memory
v_out_
[
i
]
=
v_out
;
p_out_
[
i
]
=
p_out
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
MomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
CPUDenseMomentumFunctor
<
T
>
functor
(
param
,
grad
,
velocity
,
learning_rate
,
mu
,
use_nesterov
,
param_out
,
velocity_out
);
functor
();
}
else
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
param
->
numel
());
if
(
use_nesterov
)
{
DenseMomentumFunctor
<
T
,
UseNesterov
>
functor
(
param
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
param
->
numel
(),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
}
else
{
DenseMomentumFunctor
<
T
,
NoNesterov
>
functor
(
param
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
param
->
numel
(),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
}
}
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// sparse update embedding with selectedrows
auto
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
// sparse update maybe empty.
if
(
grad
->
rows
().
size
()
==
0
)
{
VLOG
(
3
)
<<
"Grad SelectedRows contains no data!"
;
return
;
}
auto
*
merged_grad
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
()
->
GetMutable
<
framework
::
SelectedRows
>
();
math
::
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
*
grad
,
merged_grad
);
const
int64_t
*
rows
=
nullptr
;
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
rows
=
merged_grad
->
rows
().
CUDAData
(
ctx
.
GetPlace
());
}
else
{
#endif
rows
=
merged_grad
->
rows
().
data
();
#ifdef PADDLE_WITH_CUDA
}
#endif
int64_t
row_numel
=
merged_grad
->
value
().
numel
()
/
merged_grad
->
rows
().
size
();
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
param
->
numel
());
if
(
use_nesterov
)
{
SparseMomentumFunctor
<
T
,
UseNesterov
>
functor
(
param
->
data
<
T
>
(),
merged_grad
->
value
().
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
rows
,
row_numel
,
static_cast
<
int64_t
>
(
merged_grad
->
rows
().
size
()),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
}
else
{
SparseMomentumFunctor
<
T
,
NoNesterov
>
functor
(
param
->
data
<
T
>
(),
merged_grad
->
value
().
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
rows
,
row_numel
,
static_cast
<
int64_t
>
(
merged_grad
->
rows
().
size
()),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
}
}
else
{
PADDLE_THROW
(
string
::
Sprintf
(
"MomentumOp only supports LoDTensor or SelectedRows "
"gradient, but the received Variable Type is %s"
,
grad_var
->
Type
().
name
()));
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/rmsprop_op.cc
浏览文件 @
b4a9c184
...
...
@@ -32,6 +32,11 @@ class RmspropOp : public framework::OperatorWithKernel {
"Input(Grad) of RmspropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Moment"
),
"Input(Moment) of RmspropOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputsVarType
(
"Param"
).
front
()
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input var's type should be LoDTensor, but the received is %s"
,
ctx
->
Inputs
(
"Param"
).
front
(),
ctx
->
GetInputsVarType
(
"Param"
).
front
());
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ParamOut"
),
"Output(param_out) of RmspropOp should not be null."
);
...
...
paddle/fluid/operators/rmsprop_op.h
浏览文件 @
b4a9c184
...
...
@@ -132,6 +132,11 @@ class RmspropOpKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
LoDTensor
=
framework
::
LoDTensor
;
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
auto
*
param_out
=
ctx
.
Output
<
LoDTensor
>
(
"ParamOut"
);
auto
*
moment_out
=
ctx
.
Output
<
LoDTensor
>
(
"MomentOut"
);
...
...
paddle/fluid/operators/sgd_op.cc
浏览文件 @
b4a9c184
...
...
@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of SGDOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
@@ -42,7 +42,7 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"Param"
));
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
...
...
@@ -50,17 +50,20 @@ class SGDOp : public framework::OperatorWithKernel {
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var
=
op_desc
.
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
op_desc
.
Output
(
"ParamOut"
))
{
if
(
block
->
FindRecursiveOrCreateVar
(
input_var
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var_n
=
op_desc
.
Input
(
"Param"
)[
0
];
auto
in_var_type
=
block
->
FindRecursiveOrCreateVar
(
input_var_n
).
GetType
();
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s"
,
input_var_n
,
in_var_type
);
for
(
auto
&
out_var_n
:
op_desc
.
Output
(
"ParamOut"
))
{
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_n
);
if
(
out_var
.
GetType
()
!=
in_var_type
)
{
out_var
.
SetType
(
in_var_type
);
}
}
}
...
...
paddle/fluid/operators/sgd_op.cu
浏览文件 @
b4a9c184
...
...
@@ -57,6 +57,12 @@ template <typename T>
class
SGDOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
ctx
.
Inputs
(
"Param"
).
front
(),
param_var
->
Type
().
name
());
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
...
...
python/paddle/fluid/framework.py
浏览文件 @
b4a9c184
...
...
@@ -1522,13 +1522,17 @@ class Program(object):
>>> with program.lr_schedule_guard():
>>> lr = lr * decay
"""
tmp_role
=
self
.
_current_role
tmp_var
=
self
.
_op_role_var
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
self
.
_current_role
=
OpRole
.
LRSched
# TODO(typhoonzero): how to set target learning rate var
self
.
_op_role_var
=
[]
yield
self
.
_op_role_var
=
[]
self
.
_current_role
=
OpRole
.
Forward
self
.
_op_role_var
=
tmp_var
self
.
_current_role
=
tmp_role
def
__str__
(
self
):
"""
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
b4a9c184
...
...
@@ -15,7 +15,7 @@
from
__future__
import
print_function
import
re
from
collections
import
defaultdict
from
paddle.fluid.framework
import
Program
,
Variable
,
name_scope
from
paddle.fluid.framework
import
Program
,
Variable
,
name_scope
,
default_main_program
from
.
import
framework
from
.
import
layers
from
.backward
import
append_backward
...
...
@@ -111,7 +111,8 @@ class Optimizer(object):
if
param_lr
==
1.0
:
return
self
.
_global_learning_rate
()
else
:
return
self
.
_global_learning_rate
()
*
param_lr
with
default_main_program
().
_lr_schedule_guard
():
return
self
.
_global_learning_rate
()
*
param_lr
def
_create_accumulators
(
self
,
block
,
parameters
):
"""Create all accumulators needed by the parameters
...
...
@@ -659,6 +660,9 @@ class AdamaxOptimizer(Optimizer):
optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, AdamaxOptimizer doesn't support sparse gradient.
"""
_moment_acc_str
=
"moment"
_inf_norm_acc_str
=
"inf_norm"
...
...
@@ -778,6 +782,9 @@ class DecayedAdagradOptimizer(Optimizer):
optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, DecayedAdagradOptimizer doesn't support sparse gradient.
"""
_moment_acc_str
=
"moment"
...
...
@@ -858,6 +865,9 @@ class AdadeltaOptimizer(Optimizer):
optimizer = fluid.optimizer.Adadelta(
learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, AdadeltaOptimizer doesn't support sparse gradient.
"""
_avg_squared_grad_acc_str
=
"_avg_squared_grad"
...
...
@@ -1126,6 +1136,9 @@ class FtrlOptimizer(Optimizer):
optimizer = fluid.optimizer.Ftrl(0.0001)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, FtrlOptimizer doesn't support sparse gradient.
"""
_squared_acc_str
=
"squared"
...
...
python/paddle/fluid/tests/unittests/dist_simnet_bow.py
浏览文件 @
b4a9c184
...
...
@@ -81,7 +81,10 @@ def get_optimizer():
return
optimizer
def
train_network
(
batch_size
,
is_distributed
=
False
,
is_sparse
=
False
):
def
train_network
(
batch_size
,
is_distributed
=
False
,
is_sparse
=
False
,
is_self_contained_lr
=
False
):
# query
q
=
fluid
.
layers
.
data
(
name
=
"query_ids"
,
shape
=
[
1
],
dtype
=
"int64"
,
lod_level
=
1
)
...
...
@@ -93,7 +96,9 @@ def train_network(batch_size, is_distributed=False, is_sparse=False):
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
),
name
=
"__emb__"
,
learning_rate
=
emb_lr
),
learning_rate
=
emb_lr
)
if
is_self_contained_lr
else
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
),
name
=
"__emb__"
),
is_sparse
=
is_sparse
)
## vsum
q_sum
=
fluid
.
layers
.
sequence_pool
(
input
=
q_emb
,
pool_type
=
'sum'
)
...
...
@@ -119,7 +124,9 @@ def train_network(batch_size, is_distributed=False, is_sparse=False):
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
),
name
=
"__emb__"
,
learning_rate
=
emb_lr
),
learning_rate
=
emb_lr
)
if
is_self_contained_lr
else
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
),
name
=
"__emb__"
),
is_sparse
=
is_sparse
)
## vsum
pt_sum
=
fluid
.
layers
.
sequence_pool
(
input
=
pt_emb
,
pool_type
=
'sum'
)
...
...
@@ -144,7 +151,9 @@ def train_network(batch_size, is_distributed=False, is_sparse=False):
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
),
name
=
"__emb__"
,
learning_rate
=
emb_lr
),
learning_rate
=
emb_lr
)
if
is_self_contained_lr
else
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
),
name
=
"__emb__"
),
is_sparse
=
is_sparse
)
## vsum
nt_sum
=
fluid
.
layers
.
sequence_pool
(
input
=
nt_emb
,
pool_type
=
'sum'
)
...
...
@@ -220,7 +229,10 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
def
get_model
(
self
,
batch_size
=
2
):
# Train program
avg_cost
,
acc
,
predict
=
\
train_network
(
batch_size
,
bool
(
int
(
os
.
environ
[
"IS_DISTRIBUTED"
])),
bool
(
int
(
os
.
environ
[
"IS_SPARSE"
])))
train_network
(
batch_size
,
bool
(
int
(
os
.
environ
[
"IS_DISTRIBUTED"
])),
bool
(
int
(
os
.
environ
[
"IS_SPARSE"
])),
bool
(
int
(
os
.
environ
[
"IS_SELF_CONTAINED_LR"
])))
inference_program
=
fluid
.
default_main_program
().
clone
()
...
...
python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
浏览文件 @
b4a9c184
...
...
@@ -25,7 +25,11 @@ class TestDistSimnetBowDense2x2(TestDistBase):
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
}
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
,
'IS_SELF_CONTAINED_LR'
:
'1'
}
self
.
check_with_place
(
"dist_simnet_bow.py"
,
delta
=
1e-5
,
...
...
@@ -39,7 +43,11 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
}
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
,
'IS_SELF_CONTAINED_LR'
:
'1'
}
self
.
check_with_place
(
"dist_simnet_bow.py"
,
delta
=
100
,
...
...
@@ -53,7 +61,11 @@ class TestDistSimnetBowSparse2x2(TestDistBase):
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'1'
}
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'1'
,
'IS_SELF_CONTAINED_LR'
:
'1'
}
self
.
check_with_place
(
"dist_simnet_bow.py"
,
delta
=
1e-5
,
...
...
@@ -67,7 +79,11 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'1'
}
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'1'
,
'IS_SELF_CONTAINED_LR'
:
'1'
}
self
.
check_with_place
(
"dist_simnet_bow.py"
,
delta
=
100
,
...
...
@@ -75,5 +91,59 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
need_envs
=
need_envs
)
class
TestDistSimnetBow2x2LookupTableSync
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'1'
,
"IS_SPARSE"
:
'1'
,
'IS_SELF_CONTAINED_LR'
:
'1'
}
self
.
check_with_place
(
"dist_simnet_bow.py"
,
delta
=
1e-5
,
check_error_log
=
False
,
need_envs
=
need_envs
)
class
TestDistSimnetBow2x2LookupTableAsync
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'1'
,
"IS_SPARSE"
:
'1'
,
'IS_SELF_CONTAINED_LR'
:
'1'
}
self
.
check_with_place
(
"dist_simnet_bow.py"
,
delta
=
100
,
check_error_log
=
False
,
need_envs
=
need_envs
)
class
TestDistSimnetBow2x2LookupTableNotContainLRSync
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'1'
,
"IS_SPARSE"
:
'1'
,
'IS_SELF_CONTAINED_LR'
:
'0'
}
self
.
check_with_place
(
"dist_simnet_bow.py"
,
delta
=
1e-5
,
check_error_log
=
False
,
need_envs
=
need_envs
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_momentum_op.py
浏览文件 @
b4a9c184
...
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
op_test
import
OpTest
...
...
@@ -88,5 +90,97 @@ class TestMomentumOp2(OpTest):
self
.
check_output
()
class
TestSparseMomentumOp
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
use_nesterov
=
False
def
check_with_place
(
self
,
place
):
self
.
init_kernel
()
scope
=
core
.
Scope
()
# create and initialize Grad Variable
height
=
10
rows
=
[
0
,
4
,
7
]
row_numel
=
12
mu
=
1.0
use_nesterov
=
self
.
use_nesterov
# create and initialize Param Variable
param
=
scope
.
var
(
'Param'
).
get_tensor
()
param_array
=
np
.
full
((
height
,
row_numel
),
5.0
).
astype
(
"float32"
)
param
.
set
(
param_array
,
place
)
param_out
=
scope
.
var
(
"ParamOut"
).
get_tensor
()
param_out_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
param_out
.
set
(
param_out_array
,
place
)
grad_selected_rows
=
scope
.
var
(
'Grad'
).
get_selected_rows
()
grad_selected_rows
.
set_height
(
height
)
grad_selected_rows
.
set_rows
(
rows
)
grad_np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
grad_np_array
[
0
,
0
]
=
2.0
grad_np_array
[
2
,
8
]
=
4.0
grad_tensor
=
grad_selected_rows
.
get_tensor
()
grad_tensor
.
set
(
grad_np_array
,
place
)
velocity
=
scope
.
var
(
'Velocity'
).
get_tensor
()
velocity_np_array
=
np
.
ones
((
height
,
row_numel
)).
astype
(
"float32"
)
velocity
.
set
(
velocity_np_array
,
place
)
velocity_out
=
scope
.
var
(
'VelocityOut'
).
get_tensor
()
velocity_out_np_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
velocity_out
.
set
(
velocity_out_np_array
,
place
)
# create and initialize LeraningRate Variable
lr
=
scope
.
var
(
'LearningRate'
).
get_tensor
()
lr_array
=
np
.
full
((
1
),
2.0
).
astype
(
"float32"
)
lr
.
set
(
lr_array
,
place
)
# create and run operator
op
=
Operator
(
"momentum"
,
Param
=
'Param'
,
Grad
=
'Grad'
,
Velocity
=
'Velocity'
,
ParamOut
=
'ParamOut'
,
VelocityOut
=
'VelocityOut'
,
LearningRate
=
'LearningRate'
,
mu
=
mu
,
use_nesterov
=
use_nesterov
)
op
.
run
(
scope
,
place
)
# get and compare result
param_out_np_array
=
np
.
array
(
param_out
)
velocity_out_np_array
=
np
.
array
(
velocity_out
)
# TODO(dzh): add a more suitable general numpy interface
# for sparse update.
_grad_np_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
for
i
in
range
(
len
(
rows
)):
_grad_np_array
[
rows
[
i
]]
=
grad_np_array
[
i
]
_velocity_out
=
mu
*
velocity_np_array
+
_grad_np_array
_param
=
param_array
if
use_nesterov
:
_param_out
=
_param
-
(
_grad_np_array
+
_velocity_out
*
mu
)
*
lr_array
else
:
_param_out
=
_param
-
lr_array
*
_velocity_out
self
.
assertTrue
((
_velocity_out
==
velocity_out_np_array
).
all
())
self
.
assertTrue
((
_param_out
==
param_out_np_array
).
all
())
def
init_kernel
(
self
):
pass
def
test_sparse_momentum
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
class
TestSparseMomentumOp2
(
TestSparseMomentumOp
):
def
init_kernel
(
self
):
self
.
use_nesterov
=
True
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
b4a9c184
...
...
@@ -1118,6 +1118,7 @@ to transpile() call.")
def
_split_table_grad_and_add_send_vars
(
self
,
program
,
pserver_endpoints
):
# 2. add split_ids_op and send_op to send gradient to pservers
# there should only be one table_name
all_ops
=
program
.
global_block
().
ops
table_grad_name
=
grad_var_name
(
self
.
table_name
)
...
...
@@ -1142,7 +1143,7 @@ to transpile() call.")
if
self
.
sync_mode
else
[]
},
attrs
=
{
"sync_mode"
:
self
.
sync_mode
,
"sync_mode"
:
not
self
.
sync_mode
,
"epmap"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
...
...
@@ -1188,7 +1189,15 @@ to transpile() call.")
def
_create_table_optimize_block
(
self
,
pserver_index
,
pserver_program
,
pre_block_idx
,
grad_to_block_id
):
# STEP: create table optimize block
table_opt_block
=
pserver_program
.
_create_block
(
pre_block_idx
)
# create table param and grad var in pserver program
# create table optimize block in pserver program
table_opt_op
=
[
op
for
op
in
self
.
optimize_ops
if
'Param'
in
op
.
input_names
and
op
.
input
(
"Param"
)[
0
]
==
self
.
table_name
][
0
]
origin_param_var
=
self
.
origin_program
.
global_block
().
vars
[
self
.
table_name
]
...
...
@@ -1204,19 +1213,16 @@ to transpile() call.")
dtype
=
origin_param_var
.
dtype
,
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
persistable
=
True
)
# parameter must be selected rows
param_var
.
desc
.
set_type
(
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
)
grad_var
=
pserver_program
.
global_block
().
_clone_variable
(
self
.
origin_program
.
global_block
().
vars
[
grad_var_name
(
self
.
table_name
)])
# create table optimize block in pserver program
table_opt_op
=
[
op
for
op
in
self
.
optimize_ops
if
'Param'
in
op
.
input_names
and
op
.
input
(
"Param"
)[
0
]
==
self
.
table_name
][
0
]
table_opt_block
=
pserver_program
.
_create_block
(
pre_block_idx
)
lr_var
=
pserver_program
.
global_block
().
_clone_variable
(
self
.
origin_program
.
global_block
().
vars
[
table_opt_op
.
input
(
"LearningRate"
)[
0
]])
if
self
.
sync_mode
:
# create grad vars in pserver program
...
...
@@ -1248,8 +1254,6 @@ to transpile() call.")
grad_var
=
pserver_program
.
global_block
().
_rename_var
(
origin_grad_name
,
splited_grad_name
)
lr_var
=
pserver_program
.
global_block
().
vars
[
table_opt_op
.
input
(
"LearningRate"
)[
0
]]
inputs
=
{
"Param"
:
[
param_var
],
"Grad"
:
[
grad_var
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录