Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8329a1f1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8329a1f1
编写于
10月 14, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sparse update momentum. test=develop
上级
ce248a15
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
263 addition
and
42 deletion
+263
-42
paddle/fluid/operators/momentum_op.cc
paddle/fluid/operators/momentum_op.cc
+24
-5
paddle/fluid/operators/momentum_op.cu
paddle/fluid/operators/momentum_op.cu
+77
-17
paddle/fluid/operators/momentum_op.h
paddle/fluid/operators/momentum_op.h
+62
-20
python/paddle/fluid/tests/unittests/test_momentum_op.py
python/paddle/fluid/tests/unittests/test_momentum_op.py
+100
-0
未找到文件。
paddle/fluid/operators/momentum_op.cc
浏览文件 @
8329a1f1
...
@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
...
@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(param) of Momentum should not be null."
);
"Input(param) of Momentum should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
@@ -53,13 +53,30 @@ class MomentumOp : public framework::OperatorWithKernel {
...
@@ -53,13 +53,30 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"VelocityOut"
,
param_dim
);
ctx
->
SetOutputDim
(
"VelocityOut"
,
param_dim
);
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
auto
input_data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"Param"
));
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
class
MomentumOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var
=
op_desc
.
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
op_desc
.
Output
(
"ParamOut"
))
{
if
(
block
->
FindRecursiveOrCreateVar
(
input_var
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
};
class
MomentumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
MomentumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
...
@@ -110,6 +127,8 @@ $$
...
@@ -110,6 +127,8 @@ $$
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
momentum
,
ops
::
MomentumOp
,
ops
::
MomentumOpMaker
);
REGISTER_OPERATOR
(
momentum
,
ops
::
MomentumOp
,
ops
::
MomentumOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
MomentumOpInferVarType
);
REGISTER_OP_CPU_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
float
>
,
REGISTER_OP_CPU_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
float
>
,
ops
::
MomentumOpKernel
<
double
>
);
ops
::
MomentumOpKernel
<
double
>
);
paddle/fluid/operators/momentum_op.cu
浏览文件 @
8329a1f1
...
@@ -42,32 +42,92 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
...
@@ -42,32 +42,92 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
}
}
}
}
template
<
typename
T
>
__global__
void
SparseMomentumKernel
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
lr
,
const
T
mu
,
const
int64_t
*
grad_rows
,
const
size_t
grad_row_numel
,
const
size_t
grad_row_size
,
const
T
use_nesterov
,
T
*
p_out
,
T
*
v_out
)
{
for
(
int
i
=
blockIdx
.
x
;
i
<
grad_row_size
;
i
+=
gridDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
grad_row_numel
;
j
+=
blockDim
.
x
)
{
size_t
p_i
=
grad_rows
[
i
]
*
grad_row_numel
+
j
;
size_t
g_i
=
i
*
grad_row_numel
+
j
;
v_out
[
g_i
]
=
v
[
g_i
]
*
mu
+
g
[
g_i
];
if
(
use_nesterov
)
{
p_out
[
p_i
]
=
p
[
p_i
]
-
(
g
[
g_i
]
+
v_out
[
g_i
]
*
mu
)
*
lr
[
0
];
}
else
{
p_out
[
p_i
]
=
p
[
p_i
]
-
v_out
[
g_i
]
*
lr
[
0
];
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
class
MomentumOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
MomentumOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
velocity_var
=
ctx
.
InputVar
(
"Velocity"
);
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
PADDLE_ENFORCE
(
velocity_var
->
IsType
<
framework
::
LoDTensor
>
(),
"Unmatched Type of Param and Grad"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
T
*
p_out
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
p_out
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
v_out
=
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
v_out
=
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
*
p
=
param
->
data
<
T
>
();
auto
*
p
=
param
->
data
<
T
>
();
auto
*
v
=
velocity
->
data
<
T
>
();
auto
*
v
=
velocity
->
data
<
T
>
();
auto
*
g
=
grad
->
data
<
T
>
();
auto
*
g
=
grad
->
data
<
T
>
();
auto
*
lr
=
learning_rate
->
data
<
T
>
();
auto
*
lr
=
learning_rate
->
data
<
T
>
();
int
block
=
512
;
const
int
kThreadPerBlock
=
256
;
int
grid
=
(
param
->
numel
()
+
block
-
1
)
/
block
;
int
grid
=
(
param
->
numel
()
+
kThreadPerBlock
-
1
)
/
kThreadPerBlock
;
MomentumKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
MomentumKernel
<
T
><<<
grid
,
kThreadPerBlock
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
p
,
g
,
v
,
lr
,
mu
,
param
->
numel
(),
use_nesterov
,
p_out
,
v_out
);
p
,
g
,
v
,
lr
,
mu
,
param
->
numel
(),
use_nesterov
,
p_out
,
v_out
);
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// sparse update embedding with selectedrows
PADDLE_ENFORCE
(
velocity_var
->
IsType
<
framework
::
SelectedRows
>
(),
"Unmatched Type of Param and Grad"
);
auto
velocity
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
SelectedRows
>
(
"VelocityOut"
);
// sparse update maybe empty.
if
(
grad
->
rows
().
size
()
==
0
)
{
return
;
}
PADDLE_ENFORCE
(
grad
->
height
()
==
velocity
->
height
(),
"Unmatched gradient and velocity."
);
auto
*
p_out
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
v_out
=
velocity_out
->
mutable_value
()
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
lr
=
learning_rate
->
data
<
T
>
();
auto
*
p
=
param
->
data
<
T
>
();
auto
*
g
=
grad
->
value
().
data
<
T
>
();
auto
*
v
=
velocity
->
value
().
data
<
T
>
();
size_t
grad_row_numel
=
grad
->
value
().
numel
()
/
grad
->
rows
().
size
();
size_t
grad_row_size
=
grad
->
rows
().
size
();
framework
::
Vector
<
int64_t
>
rows
(
grad
->
rows
());
const
int
kThreadPerBlock
=
256
;
int
grid
=
(
param
->
numel
()
+
kThreadPerBlock
-
1
)
/
kThreadPerBlock
;
SparseMomentumKernel
<
T
><<<
grid
,
kThreadPerBlock
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
p
,
g
,
v
,
lr
,
mu
,
rows
.
CUDAData
(
ctx
.
GetPlace
()),
grad_row_numel
,
grad
->
rows
().
size
(),
use_nesterov
,
p_out
,
v_out
);
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
}
}
}
};
};
...
...
paddle/fluid/operators/momentum_op.h
浏览文件 @
8329a1f1
...
@@ -23,19 +23,22 @@ template <typename T>
...
@@ -23,19 +23,22 @@ template <typename T>
class
MomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
MomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
velocity_var
=
ctx
.
InputVar
(
"Velocity"
);
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
PADDLE_ENFORCE
(
velocity_var
->
IsType
<
framework
::
LoDTensor
>
(),
"Unmatched Type of Param and Grad"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
p_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
p_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
v_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out
);
auto
v_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out
);
...
@@ -50,6 +53,45 @@ class MomentumOpKernel : public framework::OpKernel<T> {
...
@@ -50,6 +53,45 @@ class MomentumOpKernel : public framework::OpKernel<T> {
}
else
{
}
else
{
p_out
=
p
-
lr
[
0
]
*
v_out
;
p_out
=
p
-
lr
[
0
]
*
v_out
;
}
}
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// sparse update embedding with selectedrows
PADDLE_ENFORCE
(
velocity_var
->
IsType
<
framework
::
SelectedRows
>
(),
"Unmatched Type of Param and Grad"
);
auto
velocity
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Velocity"
);
auto
grad
=
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
SelectedRows
>
(
"VelocityOut"
);
// sparse update maybe empty.
if
(
grad
->
rows
().
size
()
==
0
)
{
return
;
}
PADDLE_ENFORCE
(
grad
->
height
()
==
velocity
->
height
(),
"Unmatched gradient and velocity."
);
auto
*
p_out
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
v_out
=
velocity_out
->
mutable_value
()
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
lr
=
learning_rate
->
data
<
T
>
();
auto
*
p
=
param
->
data
<
T
>
();
auto
*
g
=
grad
->
value
().
data
<
T
>
();
auto
*
v
=
velocity
->
value
().
data
<
T
>
();
size_t
grad_row_numel
=
grad
->
value
().
numel
()
/
grad
->
rows
().
size
();
for
(
size_t
i
=
0
;
i
<
grad
->
rows
().
size
();
++
i
)
{
size_t
grad_row_index
=
grad
->
rows
()[
i
];
for
(
size_t
j
=
0
;
j
<
grad_row_numel
;
++
j
)
{
size_t
p_i
=
grad_row_index
*
grad_row_numel
+
j
;
size_t
g_i
=
i
*
grad_row_numel
+
j
;
v_out
[
g_i
]
=
v
[
g_i
]
*
mu
+
g
[
g_i
];
if
(
use_nesterov
)
{
p_out
[
p_i
]
=
p
[
p_i
]
-
(
g
[
g_i
]
+
v_out
[
g_i
]
*
mu
)
*
lr
[
0
];
}
else
{
p_out
[
p_i
]
=
p
[
p_i
]
-
v_out
[
g_i
]
*
lr
[
0
];
}
}
}
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
}
}
}
};
};
...
...
python/paddle/fluid/tests/unittests/test_momentum_op.py
浏览文件 @
8329a1f1
...
@@ -16,6 +16,8 @@ from __future__ import print_function
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
op_test
import
OpTest
from
op_test
import
OpTest
...
@@ -88,5 +90,103 @@ class TestMomentumOp2(OpTest):
...
@@ -88,5 +90,103 @@ class TestMomentumOp2(OpTest):
self
.
check_output
()
self
.
check_output
()
class
TestSparseMomentumOp
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
use_nesterov
=
False
def
check_with_place
(
self
,
place
):
self
.
init_kernel
()
scope
=
core
.
Scope
()
# create and initialize Grad Variable
height
=
10
rows
=
[
0
,
4
,
7
]
row_numel
=
12
mu
=
1.0
use_nesterov
=
self
.
use_nesterov
# create and initialize Param Variable
param
=
scope
.
var
(
'Param'
).
get_tensor
()
param_array
=
np
.
full
((
height
,
row_numel
),
5.0
).
astype
(
"float32"
)
param
.
set
(
param_array
,
place
)
param_out
=
scope
.
var
(
"ParamOut"
).
get_tensor
()
param_out_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
param_out
.
set
(
param_out_array
,
place
)
grad_selected_rows
=
scope
.
var
(
'Grad'
).
get_selected_rows
()
grad_selected_rows
.
set_height
(
height
)
grad_selected_rows
.
set_rows
(
rows
)
grad_np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
grad_np_array
[
0
,
0
]
=
2.0
grad_np_array
[
2
,
8
]
=
4.0
grad_tensor
=
grad_selected_rows
.
get_tensor
()
grad_tensor
.
set
(
grad_np_array
,
place
)
velocity_selected_rows
=
scope
.
var
(
'Velocity'
).
get_selected_rows
()
velocity_selected_rows
.
set_height
(
height
)
velocity_selected_rows
.
set_rows
(
rows
)
velocity_np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
velocity_np_array
[
0
,
0
]
=
2.0
velocity_np_array
[
2
,
8
]
=
2.0
velocity_tensor
=
velocity_selected_rows
.
get_tensor
()
velocity_tensor
.
set
(
velocity_np_array
,
place
)
velocity_out_selected_rows
=
scope
.
var
(
'VelocityOut'
).
get_selected_rows
(
)
velocity_out_selected_rows
.
set_height
(
height
)
velocity_out_selected_rows
.
set_rows
(
rows
)
velocity_out_np_array
=
np
.
full
((
len
(
rows
),
row_numel
),
0.0
).
astype
(
"float32"
)
velocity_out_tensor
=
velocity_out_selected_rows
.
get_tensor
()
velocity_out_tensor
.
set
(
velocity_out_np_array
,
place
)
# create and initialize LeraningRate Variable
lr
=
scope
.
var
(
'LearningRate'
).
get_tensor
()
lr_array
=
np
.
full
((
1
),
2.0
).
astype
(
"float32"
)
lr
.
set
(
lr_array
,
place
)
# create and run operator
op
=
Operator
(
"momentum"
,
Param
=
'Param'
,
Grad
=
'Grad'
,
Velocity
=
'Velocity'
,
ParamOut
=
'ParamOut'
,
VelocityOut
=
'VelocityOut'
,
LearningRate
=
'LearningRate'
,
mu
=
mu
,
use_nesterov
=
use_nesterov
)
op
.
run
(
scope
,
place
)
# get and compare result
param_out_np_array
=
np
.
array
(
param_out
)
velocity_out_np_array
=
np
.
array
(
velocity_out_tensor
)
# TODO(dzh): add a more suitable general numpy interface
# for sparse update.
_velocity_out
=
mu
*
velocity_np_array
+
grad_np_array
_param
=
param_array
[
rows
]
if
use_nesterov
:
_param_out
=
_param
-
grad_np_array
*
lr_array
-
\
_velocity_out
*
mu
*
lr_array
else
:
_param_out
=
_param
-
lr
*
_velocity_out
self
.
assertTrue
((
_param_out
==
param_out_np_array
[
rows
]).
all
())
self
.
assertTrue
((
_velocity_out
==
velocity_out_np_array
).
all
())
def
init_kernel
(
self
):
pass
def
test_sparse_momentum
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
class
TestSparseMomentumOp2
(
TestSparseMomentumOp
):
def
init_kernel
(
self
):
self
.
use_nesterov
=
True
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录