Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3b1294ae
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3b1294ae
编写于
6月 06, 2017
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add checkpoint interface: set state, get state"
上级
fd8c5107
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
121 addition
and
17 deletion
+121
-17
paddle/optimizer/optimizer.cc
paddle/optimizer/optimizer.cc
+12
-1
paddle/optimizer/optimizer.h
paddle/optimizer/optimizer.h
+5
-1
paddle/optimizer/parameter_optimizer.h
paddle/optimizer/parameter_optimizer.h
+4
-0
paddle/optimizer/parameter_optimizer_test.cpp
paddle/optimizer/parameter_optimizer_test.cpp
+15
-15
paddle/optimizer/serialization.h
paddle/optimizer/serialization.h
+36
-0
paddle/optimizer/sgd_optimizer.h
paddle/optimizer/sgd_optimizer.h
+2
-0
paddle/optimizer/sgd_optmizer.cc
paddle/optimizer/sgd_optmizer.cc
+27
-0
proto/OptimizerConfig.proto
proto/OptimizerConfig.proto
+20
-0
未找到文件。
paddle/optimizer/optimizer.cc
浏览文件 @
3b1294ae
...
@@ -34,10 +34,16 @@ struct paddle_optimizer {
...
@@ -34,10 +34,16 @@ struct paddle_optimizer {
};
};
paddle_optimizer
*
paddle_create_optimizer
(
const
unsigned
char
*
config_proto
,
paddle_optimizer
*
paddle_create_optimizer
(
const
unsigned
char
*
config_proto
,
int
config_proto_len
)
{
const
int
config_proto_len
,
const
char
**
state
,
const
int
state_size
)
{
paddle_optimizer
*
optimizer
=
new
paddle_optimizer
;
paddle_optimizer
*
optimizer
=
new
paddle_optimizer
;
std
::
string
config
(
config_proto
,
config_proto
+
config_proto_len
);
std
::
string
config
(
config_proto
,
config_proto
+
config_proto_len
);
optimizer
->
impl
=
ParameterOptimizer
::
Create
(
config
);
optimizer
->
impl
=
ParameterOptimizer
::
Create
(
config
);
if
(
state
!=
nullptr
)
{
std
::
string
s
(
*
state
,
*
state
+
state_size
);
optimizer
->
impl
->
DeSerializeState
(
s
);
}
return
optimizer
;
return
optimizer
;
}
}
...
@@ -71,3 +77,8 @@ void* paddle_optimizer_get_weights(paddle_optimizer* o) {
...
@@ -71,3 +77,8 @@ void* paddle_optimizer_get_weights(paddle_optimizer* o) {
void
*
buffer
=
(
void
*
)
o
->
impl
->
get_weight
();
void
*
buffer
=
(
void
*
)
o
->
impl
->
get_weight
();
return
buffer
;
return
buffer
;
}
}
int
paddle_optimizer_get_state
(
paddle_optimizer
*
o
,
const
char
*
state
)
{
state
=
o
->
impl
->
SerializeState
();
return
PADDLE_SUCCESS
;
}
paddle/optimizer/optimizer.h
浏览文件 @
3b1294ae
...
@@ -45,7 +45,9 @@ typedef struct paddle_optimizer paddle_optimizer;
...
@@ -45,7 +45,9 @@ typedef struct paddle_optimizer paddle_optimizer;
* @return return optimizer instance
* @return return optimizer instance
*/
*/
paddle_optimizer
*
paddle_create_optimizer
(
const
unsigned
char
*
config_proto
,
paddle_optimizer
*
paddle_create_optimizer
(
const
unsigned
char
*
config_proto
,
int
config_proto_len
);
const
int
config_proto_len
,
const
char
**
state
,
const
int
state_size
);
/**
/**
* @brief release optimizer
* @brief release optimizer
...
@@ -86,6 +88,8 @@ int paddle_optimizer_set_weights(paddle_optimizer* o,
...
@@ -86,6 +88,8 @@ int paddle_optimizer_set_weights(paddle_optimizer* o,
*/
*/
void
*
paddle_optimizer_get_weights
(
paddle_optimizer
*
o
);
void
*
paddle_optimizer_get_weights
(
paddle_optimizer
*
o
);
int
paddle_optimizer_get_state
(
paddle_optimizer
*
o
,
const
char
*
state
);
#ifdef __cplusplus
#ifdef __cplusplus
}
}
#endif
#endif
...
...
paddle/optimizer/parameter_optimizer.h
浏览文件 @
3b1294ae
...
@@ -11,6 +11,8 @@
...
@@ -11,6 +11,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
optimizer
{
namespace
optimizer
{
const
std
::
string
kOptimizerVersion
=
"1.0"
;
class
ParameterOptimizer
{
class
ParameterOptimizer
{
public:
public:
/**
/**
...
@@ -21,6 +23,8 @@ public:
...
@@ -21,6 +23,8 @@ public:
virtual
~
ParameterOptimizer
()
{
delete
parameter_
;
};
virtual
~
ParameterOptimizer
()
{
delete
parameter_
;
};
static
ParameterOptimizer
*
Create
(
const
std
::
string
&
config_proto
);
static
ParameterOptimizer
*
Create
(
const
std
::
string
&
config_proto
);
virtual
const
char
*
SerializeState
();
virtual
void
DeSerializeState
(
const
std
::
string
&
state
);
virtual
void
Update
(
const
Tensor
*
gradient
)
=
0
;
virtual
void
Update
(
const
Tensor
*
gradient
)
=
0
;
virtual
real
*
get_weight
()
const
;
virtual
real
*
get_weight
()
const
;
virtual
void
set_weight
(
Tensor
*
parameter
);
virtual
void
set_weight
(
Tensor
*
parameter
);
...
...
paddle/optimizer/parameter_optimizer_test.cpp
浏览文件 @
3b1294ae
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
using
namespace
paddle
;
using
namespace
paddle
;
using
namespace
paddle
::
optimizer
;
using
namespace
paddle
::
optimizer
;
Tensor
*
fill_n_
Tensor
(
size_t
size
)
{
Tensor
*
Fill
Tensor
(
size_t
size
)
{
real
*
ptr
=
new
real
[
size
];
real
*
ptr
=
new
real
[
size
];
Tensor
*
param
=
new
Tensor
(
ptr
,
size
);
Tensor
*
param
=
new
Tensor
(
ptr
,
size
);
Tensor
&
p
=
*
param
;
Tensor
&
p
=
*
param
;
...
@@ -20,7 +20,7 @@ Tensor* fill_n_Tensor(size_t size) {
...
@@ -20,7 +20,7 @@ Tensor* fill_n_Tensor(size_t size) {
return
param
;
return
param
;
}
}
Tensor
*
fix_n_
Tensor
(
size_t
size
)
{
Tensor
*
Fixed
Tensor
(
size_t
size
)
{
real
*
ptr
=
new
real
[
size
];
real
*
ptr
=
new
real
[
size
];
Tensor
*
param
=
new
Tensor
(
ptr
,
size
);
Tensor
*
param
=
new
Tensor
(
ptr
,
size
);
Tensor
&
p
=
*
param
;
Tensor
&
p
=
*
param
;
...
@@ -36,12 +36,12 @@ public:
...
@@ -36,12 +36,12 @@ public:
const
size_t
size
=
5
;
const
size_t
size
=
5
;
virtual
void
SetUp
()
{
virtual
void
SetUp
()
{
create_sgd
();
CreateSGD
();
create_a
dam
();
CreateA
dam
();
}
}
virtual
void
TearDown
()
{}
virtual
void
TearDown
()
{}
void
create_sgd
()
{
void
CreateSGD
()
{
config
.
set_optimizer
(
OptimizerConfig
::
SGD
);
config
.
set_optimizer
(
OptimizerConfig
::
SGD
);
config
.
mutable_sgd
()
->
set_momentum
(
0.0
);
config
.
mutable_sgd
()
->
set_momentum
(
0.0
);
config
.
mutable_sgd
()
->
set_decay
(
0.0
);
config
.
mutable_sgd
()
->
set_decay
(
0.0
);
...
@@ -54,7 +54,7 @@ public:
...
@@ -54,7 +54,7 @@ public:
opts
.
push_back
(
opt
);
opts
.
push_back
(
opt
);
}
}
void
create_a
dam
()
{
void
CreateA
dam
()
{
config
.
set_optimizer
(
OptimizerConfig
::
Adam
);
config
.
set_optimizer
(
OptimizerConfig
::
Adam
);
config
.
mutable_adam
()
->
set_beta_1
(
0.9
);
config
.
mutable_adam
()
->
set_beta_1
(
0.9
);
config
.
mutable_adam
()
->
set_beta_2
(
0.1
);
config
.
mutable_adam
()
->
set_beta_2
(
0.1
);
...
@@ -66,15 +66,15 @@ public:
...
@@ -66,15 +66,15 @@ public:
ParameterOptimizer
::
Create
(
config
.
SerializeAsString
());
ParameterOptimizer
::
Create
(
config
.
SerializeAsString
());
opts
.
push_back
(
opt
);
opts
.
push_back
(
opt
);
}
}
void
test_set_w
eight
()
{
void
TestSetW
eight
()
{
Tensor
*
p
=
fill_n_
Tensor
(
size
);
Tensor
*
p
=
Fill
Tensor
(
size
);
for
(
size_t
i
=
0
;
i
<
opts
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
opts
.
size
();
++
i
)
{
opts
[
i
]
->
set_weight
(
p
);
opts
[
i
]
->
set_weight
(
p
);
}
}
}
}
void
test_get_w
eight
()
{
void
TestGetW
eight
()
{
Tensor
*
p
=
fix_n_
Tensor
(
size
);
Tensor
*
p
=
Fixed
Tensor
(
size
);
for
(
size_t
i
=
0
;
i
<
opts
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
opts
.
size
();
++
i
)
{
opts
[
i
]
->
set_weight
(
p
);
opts
[
i
]
->
set_weight
(
p
);
}
}
...
@@ -85,8 +85,8 @@ public:
...
@@ -85,8 +85,8 @@ public:
}
}
}
}
}
}
void
test_u
pdate
()
{
void
TestU
pdate
()
{
Tensor
*
g
=
fix_n_
Tensor
(
size
);
Tensor
*
g
=
Fixed
Tensor
(
size
);
for
(
size_t
i
=
0
;
i
<
opts
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
opts
.
size
();
++
i
)
{
opts
[
i
]
->
Update
(
g
);
opts
[
i
]
->
Update
(
g
);
}
}
...
@@ -98,10 +98,10 @@ private:
...
@@ -98,10 +98,10 @@ private:
};
};
TEST_F
(
OptimizerTest
,
test_set_get_weight
)
{
TEST_F
(
OptimizerTest
,
test_set_get_weight
)
{
test_set_w
eight
();
TestSetW
eight
();
test_get_w
eight
();
TestGetW
eight
();
}
}
TEST_F
(
OptimizerTest
,
test_update
)
{
test_u
pdate
();
}
TEST_F
(
OptimizerTest
,
TestUpdate
)
{
TestU
pdate
();
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
...
...
paddle/optimizer/serialization.h
0 → 100644
浏览文件 @
3b1294ae
#ifndef PADDLE_OPTIMIZER_SERIALIZARION_H
#define PADDLE_OPTIMIZER_SERIALIZARION_H
#include <sstream>
#include <string>
#include "OptimizerConfig.pb.h"
#include "paddle/utils/Logging.h"
#include "tensor.h"
namespace
paddle
{
namespace
optimizer
{
static
void
TensorToProto
(
const
Tensor
&
tensor
,
TensorProto
*
proto
)
{
proto
->
set_data_type
(
TensorProto
::
PADDLE_ELEMENT_TYPE_FLOAT32
);
proto
->
set_size
(
tensor
.
size
());
std
::
stringstream
os
;
for
(
size_t
i
=
0
;
i
<
tensor
.
size
();
++
i
)
{
os
<<
tensor
[
i
];
proto
->
add_content
(
os
.
str
());
os
.
clear
();
}
}
static
void
ProtoToTensor
(
const
TensorProto
&
proto
,
Tensor
*
tensor
)
{
CHECK
(
proto
.
size
()
==
tensor
->
size
())
<<
"unmatch shape of proto and tensor"
;
std
::
stringstream
sin
;
for
(
auto
i
=
0
;
i
<
proto
.
content_size
();
++
i
)
{
sin
<<
proto
.
content
(
i
);
sin
>>
(
*
tensor
)[
i
];
sin
.
clear
();
}
}
}
// namespace optimizer
}
// namespace paddle
#endif
paddle/optimizer/sgd_optimizer.h
浏览文件 @
3b1294ae
...
@@ -12,6 +12,8 @@ public:
...
@@ -12,6 +12,8 @@ public:
:
ParameterOptimizer
(
lr
),
momentum_
(
m
),
decay_
(
d
),
nesterov_
(
n
)
{}
:
ParameterOptimizer
(
lr
),
momentum_
(
m
),
decay_
(
d
),
nesterov_
(
n
)
{}
virtual
~
SGDOptimizer
()
{
delete
momentums_
;
}
virtual
~
SGDOptimizer
()
{
delete
momentums_
;
}
void
Update
(
const
Tensor
*
gradient
);
void
Update
(
const
Tensor
*
gradient
);
const
char
*
SerializeState
();
void
DeSerializeState
(
const
std
::
string
&
state
);
void
set_weight
(
Tensor
*
p
);
void
set_weight
(
Tensor
*
p
);
real
*
get_weight
()
const
;
real
*
get_weight
()
const
;
...
...
paddle/optimizer/sgd_optmizer.cc
浏览文件 @
3b1294ae
#include "serialization.h"
#include "sgd_optimizer.h"
#include "sgd_optimizer.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -37,5 +38,31 @@ void SGDOptimizer::Update(const Tensor *gradient) {
...
@@ -37,5 +38,31 @@ void SGDOptimizer::Update(const Tensor *gradient) {
}
}
}
}
const
char
*
SGDOptimizer
::
SerializeState
()
{
OptimizerState
state
;
// version is a global const value
state
.
set_version
(
kOptimizerVersion
);
TensorToProto
(
*
parameter_
,
state
.
add_data
());
TensorToProto
(
*
momentums_
,
state
.
add_data
());
// state.add_data(param_proto);
// state.add_data(momentum_proto);
state
.
add_hyperparam
(
momentum_
);
return
state
.
SerializeAsString
().
c_str
();
}
void
SGDOptimizer
::
DeSerializeState
(
const
std
::
string
&
str
)
{
OptimizerState
state
;
state
.
ParseFromString
(
str
);
CHECK
(
state
.
version
()
==
kOptimizerVersion
)
<<
"error version of state"
<<
"expected : "
<<
kOptimizerVersion
<<
"get : "
<<
state
.
version
();
ProtoToTensor
(
state
.
data
(
0
),
parameter_
);
if
(
state
.
data_size
()
==
2
)
{
ProtoToTensor
(
state
.
data
(
1
),
momentums_
);
momentum_
=
state
.
hyperparam
(
0
);
}
}
}
// namespace optimizer
}
// namespace optimizer
}
// namespace paddle
}
// namespace paddle
proto/OptimizerConfig.proto
浏览文件 @
3b1294ae
...
@@ -64,6 +64,26 @@ message LinearLr {
...
@@ -64,6 +64,26 @@ message LinearLr {
optional
double
lr_decay_b
=
3
;
optional
double
lr_decay_b
=
3
;
}
}
message
TensorProto
{
enum
DataType
{
PADDLE_ELEMENT_TYPE_INT32
=
0
;
PADDLE_ELEMENT_TYPE_UINT32
=
1
;
PADDLE_ELEMENT_TYPE_INT64
=
2
;
PADDLE_ELEMENT_TYPE_UINT64
=
3
;
PADDLE_ELEMENT_TYPE_FLOAT32
=
4
;
PADDLE_ELEMENT_TYPE_FLOAT64
=
5
;
}
required
DataType
data_type
=
1
;
repeated
bytes
content
=
2
;
optional
uint64
size
=
3
;
}
message
OptimizerState
{
// match old training state with format parser
required
string
version
=
100
;
repeated
TensorProto
data
=
1
;
repeated
double
hyperparam
=
3
;
}
message
OptimizerConfig
{
message
OptimizerConfig
{
// common config of optimizer
// common config of optimizer
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录