Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ec65fa83
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ec65fa83
编写于
6月 19, 2017
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"protobuf required to optional"
上级
65d9e33b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
7 addition
and
16 deletion
+7
-16
paddle/optimizer/adam_optimizer.cc
paddle/optimizer/adam_optimizer.cc
+1
-1
paddle/optimizer/parameter_optimizer_test.cpp
paddle/optimizer/parameter_optimizer_test.cpp
+0
-9
paddle/optimizer/sgd_optimizer.cc
paddle/optimizer/sgd_optimizer.cc
+1
-1
proto/OptimizerConfig.proto
proto/OptimizerConfig.proto
+5
-5
未找到文件。
paddle/optimizer/adam_optimizer.cc
浏览文件 @
ec65fa83
...
...
@@ -26,7 +26,6 @@ const char *AdamOptimizer::SerializeState(int *state_len) {
AdamOptimizerState
state
;
// TODO(zhihong) : add lr_policy serialization
state
.
set_num_sample_passed
(
num_sample_passed_
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
momentums_
,
state
.
mutable_momentums
());
TensorToProto
(
*
velocitys_
,
state
.
mutable_velocitys
());
...
...
@@ -42,6 +41,7 @@ void AdamOptimizer::DeserializeState(const std::string &str) {
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
momentums
(),
momentums_
);
ProtoToTensor
(
state
.
velocitys
(),
velocitys_
);
}
}
// namespace optimizer
...
...
paddle/optimizer/parameter_optimizer_test.cpp
浏览文件 @
ec65fa83
...
...
@@ -45,11 +45,9 @@ public:
config_
.
mutable_sgd
()
->
set_nesterov
(
false
);
config_
.
set_lr_policy
(
OptimizerConfig
::
Const
);
config_
.
mutable_const_lr
()
->
set_learning_rate
(
0.1
);
std
::
string
str
=
config_
.
SerializeAsString
();
ParameterOptimizer
*
opt
=
ParameterOptimizer
::
Create
(
str
,
parameter
);
opts_
.
push_back
(
opt
);
opts_table_
[
opts_
.
size
()]
=
OptimizerConfig
::
SGD
;
}
void
CreateAdam
()
{
...
...
@@ -64,7 +62,6 @@ public:
std
::
string
str
=
config_
.
SerializeAsString
();
ParameterOptimizer
*
opt
=
ParameterOptimizer
::
Create
(
str
,
parameter
);
opts_
.
push_back
(
opt
);
opts_table_
[
opts_
.
size
()]
=
OptimizerConfig
::
Adam
;
}
void
TestGetWeight
()
{
...
...
@@ -86,21 +83,15 @@ public:
}
void
TestCheckPoint
()
{
std
::
map
<
OptimizerConfig
::
Optimizer
,
int
>
expected_state_len
=
{
{
OptimizerConfig
::
SGD
,
kSize
*
sizeof
(
float
)
+
sizeof
(
double
)},
{
OptimizerConfig
::
Adam
,
kSize
*
3
*
sizeof
(
float
)
+
sizeof
(
double
)},
};
for
(
size_t
i
=
0
;
i
<
opts_
.
size
();
++
i
)
{
int
state_len
=
0
;
std
::
string
state
=
opts_
[
i
]
->
SerializeState
(
&
state_len
);
EXPECT_EQ
(
state_len
,
expected_state_len
[
opts_table_
[
i
+
1
]]);
opts_
[
i
]
->
DeserializeState
(
state
);
}
}
private:
std
::
vector
<
ParameterOptimizer
*>
opts_
;
std
::
map
<
int
,
OptimizerConfig
::
Optimizer
>
opts_table_
;
OptimizerConfig
config_
;
};
...
...
paddle/optimizer/sgd_optimizer.cc
浏览文件 @
ec65fa83
...
...
@@ -42,7 +42,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) {
state
.
ParseFromString
(
str
);
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
parameter
(),
momentums_
);
if
(
momentum_
!=
0.0
)
ProtoToTensor
(
state
.
parameter
(),
momentums_
);
}
}
// namespace optimizer
...
...
proto/OptimizerConfig.proto
浏览文件 @
ec65fa83
...
...
@@ -55,12 +55,12 @@ message AdamConfig {
message
ConstLrConfig
{
// learninRate Policy
required
double
learning_rate
=
1
[
default
=
1.0
];
optional
double
learning_rate
=
1
[
default
=
1.0
];
}
message
LinearLrConfig
{
// learninRate Policy
required
double
learning_rate
=
1
[
default
=
1.0
];
optional
double
learning_rate
=
1
[
default
=
1.0
];
optional
double
lr_decay_a
=
2
;
optional
double
lr_decay_b
=
3
;
}
...
...
@@ -74,7 +74,7 @@ enum DataType {
PADDLE_ELEMENT_TYPE_FLOAT32
=
4
;
PADDLE_ELEMENT_TYPE_FLOAT64
=
5
;
}
required
DataType
data_type
=
1
;
optional
DataType
data_type
=
1
;
repeated
bytes
content
=
2
;
}
...
...
@@ -132,7 +132,7 @@ message OptimizerConfig {
Adagrad
=
3
;
Adam
=
4
;
}
required
Optimizer
optimizer
=
1
;
optional
Optimizer
optimizer
=
1
;
optional
SGDConfig
sgd
=
3
;
optional
AdadeltaConfig
adadelta
=
4
;
optional
AdagradConfig
adagrad
=
5
;
...
...
@@ -142,7 +142,7 @@ message OptimizerConfig {
Const
=
0
;
Linear
=
1
;
}
required
LrPolicy
lr_policy
=
11
;
optional
LrPolicy
lr_policy
=
11
;
optional
ConstLrConfig
const_lr
=
12
;
optional
LinearLrConfig
linear_lr
=
13
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录