Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
57bbee65
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
57bbee65
编写于
11月 16, 2017
作者:
Q
qingqing01
提交者:
GitHub
11月 16, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into cmake_speed
上级
0968c7cd
d7bf372d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
386 addition
and
38 deletion
+386
-38
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+7
-2
paddle/operators/adagrad_op.cc
paddle/operators/adagrad_op.cc
+85
-5
paddle/operators/adagrad_op.cu
paddle/operators/adagrad_op.cu
+133
-2
paddle/operators/adagrad_op.h
paddle/operators/adagrad_op.h
+45
-21
paddle/operators/sgd_op.cu
paddle/operators/sgd_op.cu
+8
-7
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+0
-1
python/paddle/v2/fluid/tests/test_adagrad_op.py
python/paddle/v2/fluid/tests/test_adagrad_op.py
+108
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
57bbee65
...
...
@@ -183,15 +183,20 @@ set(DEPS_OPS
array_to_lod_tensor_op
lstm_op
tensor_array_read_write_op
gru_op
)
gru_op
adagrad_op
sgd_op
)
op_library
(
cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op
)
op_library
(
cross_entropy_op DEPS cross_entropy
)
op_library
(
softmax_with_cross_entropy_op DEPS cross_entropy softmax
)
op_library
(
softmax_op DEPS softmax
)
op_library
(
sequence_softmax_op DEPS softmax
)
op_library
(
sum_op DEPS selected_rows_functor
)
op_library
(
sgd_op DEPS selected_rows_functor
)
op_library
(
adagrad_op DEPS selected_rows_functor
)
op_library
(
conv_op DEPS vol2col
)
op_library
(
sum_op DEPS net_op selected_rows_functor
)
op_library
(
pool_op DEPS pooling
)
op_library
(
pool_with_index_op DEPS pooling
)
op_library
(
lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table
)
...
...
paddle/operators/adagrad_op.cc
浏览文件 @
57bbee65
...
...
@@ -14,6 +14,11 @@ limitations under the License. */
#include "paddle/operators/adagrad_op.h"
#include <cmath>
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -21,7 +26,7 @@ class AdagradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Param"
),
"Input(Param) of AdagradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grad"
),
...
...
@@ -54,8 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel {
class
AdagradOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
AdagradOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
AdagradOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Param"
,
"(Tensor) Input parameter"
);
AddInput
(
"Grad"
,
"(Tensor) Input gradient"
);
...
...
@@ -87,10 +92,85 @@ for numerical stability to avoid the division by zero error.
)DOC"
);
}
};
namespace
{
size_t
FindPos
(
const
std
::
vector
<
int64_t
>&
rows
,
int64_t
value
)
{
return
std
::
find
(
rows
.
begin
(),
rows
.
end
(),
value
)
-
rows
.
begin
();
}
}
// namespace
template
<
typename
T
>
struct
SparseAdagradFunctor
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
grad
,
const
framework
::
Tensor
&
learning_rate
,
T
epsilon
,
framework
::
Tensor
*
moment
,
framework
::
Tensor
*
param
)
{
// 1. g_m.rows = set(g.rows)
auto
grad_rows
=
grad
.
rows
();
std
::
set
<
int64_t
>
row_set
(
grad_rows
.
begin
(),
grad_rows
.
end
());
std
::
vector
<
int64_t
>
merge_rows
(
row_set
.
begin
(),
row_set
.
end
());
auto
grad_width
=
grad
.
value
().
dims
()[
1
];
std
::
unique_ptr
<
framework
::
SelectedRows
>
grad_merge
{
new
framework
::
SelectedRows
()};
grad_merge
->
set_rows
(
merge_rows
);
grad_merge
->
set_height
(
grad
.
height
());
grad_merge
->
mutable_value
()
->
mutable_data
<
T
>
(
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
merge_rows
.
size
()),
grad_width
}),
context
.
GetPlace
());
math
::
SetConstant
<
platform
::
CPUPlace
,
T
>
constant_functor
;
constant_functor
(
context
,
grad_merge
->
mutable_value
(),
0.0
);
auto
*
grad_merge_data
=
grad_merge
->
mutable_value
()
->
data
<
T
>
();
auto
*
grad_data
=
grad
.
value
().
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
grad_rows
.
size
();
i
++
)
{
size_t
grad_merge_i
=
FindPos
(
merge_rows
,
grad_rows
[
i
]);
for
(
int64_t
j
=
0
;
j
<
grad_width
;
j
++
)
{
grad_merge_data
[
grad_merge_i
*
grad_width
+
j
]
+=
grad_data
[
i
*
grad_width
+
j
];
}
}
// 2. m += g_m * g_m
std
::
unique_ptr
<
framework
::
SelectedRows
>
grad_square
{
new
framework
::
SelectedRows
()};
grad_square
->
set_rows
(
grad_merge
->
rows
());
grad_square
->
set_height
(
grad_merge
->
height
());
grad_square
->
mutable_value
()
->
mutable_data
<
T
>
(
grad_merge
->
value
().
dims
(),
context
.
GetPlace
());
auto
gs
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
(
grad_square
->
mutable_value
()));
auto
gm
=
framework
::
EigenVector
<
T
>::
Flatten
(
grad_merge
->
value
());
gs
.
device
(
*
context
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
gm
*
gm
;
math
::
SelectedRowsAddToTensor
<
platform
::
CPUPlace
,
T
>
functor
;
functor
(
context
,
*
grad_square
,
moment
);
// 3. update parameter
auto
*
lr
=
learning_rate
.
data
<
T
>
();
auto
*
param_data
=
param
->
data
<
T
>
();
auto
*
moment_data
=
moment
->
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
merge_rows
.
size
();
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
grad_width
;
j
++
)
{
param_data
[
merge_rows
[
i
]
*
grad_width
+
j
]
-=
lr
[
0
]
*
grad_merge_data
[
i
*
grad_width
+
j
]
/
(
std
::
sqrt
(
moment_data
[
merge_rows
[
i
]
*
grad_width
+
j
])
+
epsilon
);
}
}
}
};
template
struct
SparseAdagradFunctor
<
platform
::
CPUPlace
,
float
>;
template
struct
SparseAdagradFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
adagrad
,
ops
::
AdagradOp
,
ops
::
AdagradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
adagrad
,
ops
::
AdagradOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
adagrad
,
ops
::
AdagradOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
AdagradOpKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/operators/adagrad_op.cu
浏览文件 @
57bbee65
...
...
@@ -14,7 +14,138 @@
#define EIGEN_USE_GPU
#include "paddle/operators/adagrad_op.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
{
template
<
typename
T
,
int
block_size
>
__global__
void
MergeGradKernel
(
const
T
*
grad
,
const
int64_t
*
grad_rows
,
T
*
grad_merge
,
const
int64_t
*
grad_merge_rows
,
size_t
grad_merge_rows_size
,
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
int
tid
=
threadIdx
.
x
;
__shared__
size_t
grad_merge_idx
;
if
(
tid
==
0
)
{
for
(
size_t
i
=
0
;
i
<
grad_merge_rows_size
;
i
++
)
{
if
(
grad_rows
[
ty
]
==
grad_merge_rows
[
i
])
{
grad_merge_idx
=
i
;
}
}
}
__syncthreads
();
grad
+=
ty
*
row_numel
;
grad_merge
+=
grad_merge_idx
*
row_numel
;
for
(
int
index
=
tid
;
index
<
row_numel
;
index
+=
block_size
)
{
paddle
::
platform
::
CudaAtomicAdd
(
grad_merge
+
index
,
grad
[
index
]);
}
}
template
<
typename
T
,
int
block_size
>
__global__
void
SparseAdagradFunctorKernel
(
const
T
*
grad
,
const
int64_t
*
rows
,
const
T
*
learning_rate
,
T
*
param
,
T
*
moment
,
int64_t
row_numel
,
T
epsilon
)
{
const
int
ty
=
blockIdx
.
y
;
int
tid
=
threadIdx
.
x
;
grad
+=
ty
*
row_numel
;
param
+=
rows
[
ty
]
*
row_numel
;
moment
+=
rows
[
ty
]
*
row_numel
;
for
(
int
index
=
tid
;
index
<
row_numel
;
index
+=
block_size
)
{
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle
::
platform
::
CudaAtomicAdd
(
param
+
index
,
-
1.0
*
learning_rate
[
0
]
*
grad
[
index
]
/
(
sqrt
(
moment
[
index
])
+
epsilon
));
}
}
}
// namespace
template
<
typename
T
>
struct
SparseAdagradFunctor
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
grad
,
const
framework
::
Tensor
&
learning_rate
,
T
epsilon
,
framework
::
Tensor
*
moment
,
framework
::
Tensor
*
param
)
{
// 1. g_m.rows = set(g.rows)
auto
grad_rows
=
grad
.
rows
();
std
::
set
<
int64_t
>
row_set
(
grad_rows
.
begin
(),
grad_rows
.
end
());
std
::
vector
<
int64_t
>
merge_rows
(
row_set
.
begin
(),
row_set
.
end
());
auto
grad_width
=
grad
.
value
().
dims
()[
1
];
std
::
unique_ptr
<
framework
::
SelectedRows
>
grad_merge
{
new
framework
::
SelectedRows
()};
grad_merge
->
set_rows
(
merge_rows
);
grad_merge
->
set_height
(
grad
.
height
());
grad_merge
->
mutable_value
()
->
mutable_data
<
T
>
(
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
merge_rows
.
size
()),
grad_width
}),
context
.
GetPlace
());
math
::
SetConstant
<
platform
::
GPUPlace
,
T
>
constant_functor
;
constant_functor
(
context
,
grad_merge
->
mutable_value
(),
0.0
);
auto
*
grad_merge_data
=
grad_merge
->
mutable_value
()
->
data
<
T
>
();
auto
*
grad_data
=
grad
.
value
().
data
<
T
>
();
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid1
(
1
,
grad_rows
.
size
());
MergeGradKernel
<
T
,
256
><<<
grid1
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
grad_data
,
grad
.
rows
().
data
(),
grad_merge_data
,
grad_merge
->
rows
().
data
(),
grad_merge
->
rows
().
size
(),
grad_width
);
// 2. m += g_m * g_m
std
::
unique_ptr
<
framework
::
SelectedRows
>
grad_square
{
new
framework
::
SelectedRows
()};
grad_square
->
set_rows
(
grad_merge
->
rows
());
grad_square
->
set_height
(
grad_merge
->
height
());
grad_square
->
mutable_value
()
->
mutable_data
<
T
>
(
grad_merge
->
value
().
dims
(),
context
.
GetPlace
());
auto
gs
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
(
grad_square
->
mutable_value
()));
auto
gm
=
framework
::
EigenVector
<
T
>::
Flatten
(
grad_merge
->
value
());
gs
.
device
(
*
context
.
GetEigenDevice
<
platform
::
GPUPlace
>
())
=
gm
*
gm
;
math
::
SelectedRowsAddToTensor
<
platform
::
GPUPlace
,
T
>
functor
;
functor
(
context
,
*
grad_square
,
moment
);
// 3. update parameter
auto
*
lr
=
learning_rate
.
data
<
T
>
();
auto
*
param_data
=
param
->
data
<
T
>
();
auto
*
moment_data
=
moment
->
data
<
T
>
();
dim3
grid2
(
1
,
merge_rows
.
size
());
SparseAdagradFunctorKernel
<
T
,
256
><<<
grid2
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
grad_merge_data
,
grad_merge
->
rows
().
data
(),
lr
,
param_data
,
moment_data
,
grad_width
,
epsilon
);
}
};
template
struct
SparseAdagradFunctor
<
platform
::
GPUPlace
,
float
>;
template
struct
SparseAdagradFunctor
<
platform
::
GPUPlace
,
double
>;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
adagrad
,
ops
::
AdagradOpKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
adagrad
,
ops
::
AdagradOpKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
AdagradOpKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
paddle/operators/adagrad_op.h
浏览文件 @
57bbee65
...
...
@@ -19,35 +19,59 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
struct
SparseAdagradFunctor
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
grad
,
const
framework
::
Tensor
&
learning_rate
,
T
epsilon
,
framework
::
Tensor
*
moment
,
framework
::
Tensor
*
param
);
};
template
<
typename
Place
,
typename
T
>
class
AdagradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
auto
*
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
moment_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MomentOut"
);
param_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
moment_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
param
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
));
auto
grad
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
));
auto
moment
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
));
auto
lr
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
));
auto
param_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out_tensor
);
auto
moment_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
moment_out_tensor
);
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
moment_out
.
device
(
place
)
=
moment
+
grad
*
grad
;
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
moment_out_tensor
->
numel
());
param_out
.
device
(
place
)
=
param
-
lr
.
broadcast
(
m_dsize
)
*
grad
/
(
moment_out
.
sqrt
()
+
epsilon
);
T
epsilon
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
param
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
));
auto
grad
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
));
auto
moment
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
));
auto
lr
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
));
auto
param_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out_tensor
);
auto
moment_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
moment_out_tensor
);
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
moment_out
.
device
(
place
)
=
moment
+
grad
*
grad
;
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
moment_out_tensor
->
numel
());
param_out
.
device
(
place
)
=
param
-
lr
.
broadcast
(
m_dsize
)
*
grad
/
(
moment_out
.
sqrt
()
+
epsilon
);
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
param_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_tensor
,
param_out_tensor
);
auto
*
moment_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment"
);
PADDLE_ENFORCE_EQ
(
moment_tensor
,
moment_out_tensor
);
SparseAdagradFunctor
<
Place
,
T
>
functor
;
functor
(
ctx
.
device_context
(),
*
ctx
.
Input
<
framework
::
SelectedRows
>
(
"Grad"
),
*
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
),
epsilon
,
moment_out_tensor
,
param_out_tensor
);
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
}
}
};
...
...
paddle/operators/sgd_op.cu
浏览文件 @
57bbee65
...
...
@@ -20,11 +20,11 @@ namespace paddle {
namespace
operators
{
namespace
{
template
<
typename
T
>
template
<
typename
T
,
int
block_size
>
__global__
void
SparseSGDFunctorKernel
(
const
T
*
selected_rows
,
const
int64_t
*
rows
,
const
T
*
learning_rate
,
T
*
tensor_out
,
int64_t
row_numel
,
int
block_size
)
{
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
int
tid
=
threadIdx
.
x
;
...
...
@@ -59,14 +59,15 @@ struct SparseSGDFunctor<platform::GPUPlace, T> {
auto
*
in_data
=
in_value
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
int
block_size
=
256
;
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in_rows
.
size
());
SparseSGDFunctorKernel
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
in_data
,
in_rows
.
data
(),
learning_rate
.
data
<
T
>
(),
out_data
,
in_row_numel
,
block_size
);
T
,
256
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
in_data
,
in_rows
.
data
(),
learning_rate
.
data
<
T
>
(),
out_data
,
in_row_numel
);
}
};
...
...
paddle/operators/sum_op.cc
浏览文件 @
57bbee65
...
...
@@ -12,7 +12,6 @@ limitations under the License. */
#include "paddle/operators/sum_op.h"
#include <vector>
#include "paddle/framework/var_type_inference.h"
#include "paddle/operators/net_op.h"
namespace
paddle
{
namespace
operators
{
...
...
python/paddle/v2/fluid/tests/test_adagrad_op.py
浏览文件 @
57bbee65
import
unittest
import
numpy
as
np
import
paddle.v2.fluid.core
as
core
from
paddle.v2.fluid.op
import
Operator
from
op_test
import
OpTest
import
math
class
TestAdagradOp1
(
OpTest
):
...
...
@@ -65,5 +68,110 @@ class TestAdagradOp2(OpTest):
self
.
check_output
()
class
TestSparseAdagradOp
(
unittest
.
TestCase
):
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
# create and initialize Grad Variable
height
=
10
rows
=
[
0
,
4
,
7
,
4
]
row_numel
=
12
grad_selected_rows
=
scope
.
var
(
'Grad'
).
get_selected_rows
()
grad_selected_rows
.
set_height
(
height
)
grad_selected_rows
.
set_rows
(
rows
)
np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
np_array
[
0
,
0
]
=
2.0
np_array
[
2
,
8
]
=
4.0
grad_tensor
=
grad_selected_rows
.
get_tensor
()
grad_tensor
.
set
(
np_array
,
place
)
# create and initialize Param Variable
param
=
scope
.
var
(
'Param'
).
get_tensor
()
param_array
=
np
.
full
((
height
,
row_numel
),
5.0
).
astype
(
"float32"
)
param
.
set
(
param_array
,
place
)
# create and initialize LeraningRate Variable
lr
=
scope
.
var
(
'LearningRate'
).
get_tensor
()
lr_array
=
np
.
full
((
1
),
2.0
).
astype
(
"float32"
)
lr
.
set
(
lr_array
,
place
)
# create and initialize moment Variable
moment
=
scope
.
var
(
'Moment'
).
get_tensor
()
moment_np_array
=
np
.
full
((
height
,
row_numel
),
2.0
).
astype
(
"float32"
)
moment
.
set
(
moment_np_array
,
place
)
# create and run sgd operator
adagrad_op
=
Operator
(
"adagrad"
,
Param
=
'Param'
,
Grad
=
'Grad'
,
ParamOut
=
'Param'
,
Moment
=
'Moment'
,
MomentOut
=
'Moment'
,
LearningRate
=
'LearningRate'
,
epsilon
=
2.0
)
ctx
=
core
.
DeviceContext
.
create
(
place
)
adagrad_op
.
run
(
scope
,
ctx
)
# get and compare moment result
moment_result_array
=
np
.
array
(
moment
)
self
.
assertAlmostEqual
(
6.0
,
moment_result_array
[
rows
[
0
],
0
])
self
.
assertAlmostEqual
(
3.0
,
moment_result_array
[
rows
[
0
],
2
])
self
.
assertAlmostEqual
(
2.0
,
moment_result_array
[
1
,
0
])
# 2.0 + (1.0 + 1.0)^2
self
.
assertAlmostEqual
(
6.0
,
moment_result_array
[
rows
[
1
],
10
])
self
.
assertAlmostEqual
(
6.0
,
moment_result_array
[
rows
[
3
],
4
])
self
.
assertAlmostEqual
(
2.0
,
moment_result_array
[
5
,
8
])
self
.
assertAlmostEqual
(
3.0
,
moment_result_array
[
rows
[
2
],
1
])
self
.
assertAlmostEqual
(
18.0
,
moment_result_array
[
rows
[
2
],
8
])
# get and compare param result
result_array
=
np
.
array
(
param
)
def
get_out
(
param
,
lr
,
grad
,
m
,
epsilon
):
return
param
-
lr
*
grad
/
(
math
.
sqrt
(
m
)
+
epsilon
)
self
.
assertAlmostEqual
(
get_out
(
5.0
,
2.0
,
2.0
,
6.0
,
2.0
),
result_array
[
rows
[
0
],
0
],
places
=
5
)
self
.
assertAlmostEqual
(
get_out
(
5.0
,
2.0
,
1.0
,
3.0
,
2.0
),
result_array
[
rows
[
0
],
2
],
places
=
5
)
self
.
assertAlmostEqual
(
get_out
(
5.0
,
2.0
,
0.0
,
2.0
,
2.0
),
result_array
[
1
,
0
],
places
=
5
)
# grad_merge = 1.0 + 1.0
# m = 6.0
self
.
assertAlmostEqual
(
get_out
(
5.0
,
2.0
,
2.0
,
6.0
,
2.0
),
result_array
[
rows
[
1
],
10
],
places
=
5
)
self
.
assertAlmostEqual
(
get_out
(
5.0
,
2.0
,
0.0
,
2.0
,
2.0
),
result_array
[
5
,
8
],
places
=
5
)
self
.
assertAlmostEqual
(
get_out
(
5.0
,
2.0
,
1.0
,
3.0
,
2.0
),
result_array
[
rows
[
2
],
1
],
places
=
5
)
self
.
assertAlmostEqual
(
get_out
(
5.0
,
2.0
,
4.0
,
18.0
,
2.0
),
result_array
[
rows
[
2
],
8
],
places
=
5
)
def
test_sparse_adagrad
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compile_gpu
():
places
.
append
(
core
.
GPUPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录