Commit 65d9e33b authored by dzhwinter

"modify config name"

Parent df5bc787
@@ -31,4 +31,3 @@ add_simple_unittest(test_FPException)
 add_simple_unittest(test_GpuProfiler)
 add_simple_unittest(test_BaseMatrix)
 add_simple_unittest(test_Matrix)
-add_simple_unittest(test_Matrix2)
@@ -28,7 +28,8 @@ const char *AdamOptimizer::SerializeState(int *state_len) {
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
-  TensorToProto(*velocitys_, state.mutable_momentums());
+  TensorToProto(*momentums_, state.mutable_momentums());
+  TensorToProto(*velocitys_, state.mutable_velocitys());
   auto str = state.SerializeAsString();
   *state_len = str.size();
   return str.c_str();
......
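Side note on the hunk above: the old code wrote `velocitys_` into the `momentums` field, so Adam checkpoints silently dropped one of the two moment estimates; the fix pairs each tensor with its own proto field. Separately, `return str.c_str()` hands the caller a pointer into a local `std::string` that is destroyed when the function returns. A minimal sketch of one way around that dangling pointer (the `serialized_state_` member is hypothetical, not part of this commit):

const char *AdamOptimizer::SerializeState(int *state_len) {
  AdamOptimizerState state;
  state.set_num_sample_passed(num_sample_passed_);
  TensorToProto(*parameter_, state.mutable_parameter());
  TensorToProto(*momentums_, state.mutable_momentums());
  TensorToProto(*velocitys_, state.mutable_velocitys());
  // Cache the bytes in a (hypothetical) std::string member so the
  // returned pointer stays valid after this call returns.
  serialized_state_ = state.SerializeAsString();
  *state_len = serialized_state_.size();
  return serialized_state_.c_str();
}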
@@ -21,8 +21,8 @@ public:
   double LearningRate(const uint64_t num_sample_passed) {
     return learning_rate;
   }
-  const char *SerializeState(int *state_len);
-  void DeserializeState(const std::string &state);
+  const char *SerializeState(int *state_len) { return nullptr; }
+  void DeserializeState(const std::string &state) {}
 private:
   double learning_rate;
@@ -35,8 +35,13 @@ public:
   double LearningRate(const uint64_t num_sample_passed) {
     return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
   }
-  const char *SerializeState(int *state_len);
-  void DeserializeState(const std::string &state);
+  const char *SerializeState(int *state_len) {
+    // TODO(zhihong) : add lr_policy serialization
+    return nullptr;
+  }
+  void DeserializeState(const std::string &state) {
+    // TODO(zhihong) : add lr_policy serialization
+  }
 private:
   double learning_rate;
......
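For orientation, the `LinearLr` policy above decays the rate linearly in the number of samples seen and clamps it at a floor. A self-contained mirror of that formula with illustrative constants (none of the values below come from this commit):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone mirror of LinearLr::LearningRate, for illustration only:
//   lr(n) = max(learning_rate - lr_decay_a * n, lr_decay_b)
double LinearLrRate(double lr, double decay_a, double floor, uint64_t n) {
  return std::max(lr - decay_a * n, floor);
}

int main() {
  std::printf("%g\n", LinearLrRate(0.1, 1e-5, 0.01, 0));      // 0.1
  std::printf("%g\n", LinearLrRate(0.1, 1e-5, 0.01, 5000));   // 0.05
  std::printf("%g\n", LinearLrRate(0.1, 1e-5, 0.01, 20000));  // 0.01 (floor reached)
  return 0;
}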
@@ -13,13 +13,13 @@ namespace optimizer {
 ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
                                                Tensor *parameter) {
   paddle::OptimizerConfig config;
-  CHECK(config.ParseFromString(config_proto) == 0)
+  CHECK(config.ParseFromString(config_proto) == true)
       << "failed parse optimizer config";
   auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
-    if (config.lr_policy() == OptimizerConfig::ConstLr)
+    if (config.lr_policy() == OptimizerConfig::Const)
       return new ConstLr(config.const_lr().learning_rate());
-    if (config.lr_policy() == OptimizerConfig::LinearLr)
+    if (config.lr_policy() == OptimizerConfig::Linear)
       return new LinearLr(config.linear_lr().learning_rate(),
                           config.linear_lr().lr_decay_a(),
                           config.linear_lr().lr_decay_b());
......
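One of the fixes above is behavioral, not just a rename: protobuf's `ParseFromString` returns `bool` and yields `true` on success, so the old `== 0` comparison made the `CHECK` fire on every successful parse. The comparison can be dropped entirely; a minimal sketch of the idiomatic form:

  // ParseFromString returns true on success, so check it directly.
  CHECK(config.ParseFromString(config_proto))
      << "failed parse optimizer config";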
@@ -2,11 +2,8 @@
 #include <cmath>
 #include <map>
 #include <vector>
-#include "adadelta_optimizer.h"
-#include "adagrad_optimizer.h"
-#include "adam_optimizer.h"
 #include "gtest/gtest.h"
-#include "sgd_optimizer.h"
+#include "lr_policy.h"
 using namespace paddle;
 using namespace paddle::optimizer;
@@ -41,12 +38,12 @@ public:
   virtual void TearDown() {}
   void CreateSGD() {
-    Tensor* parameter = FillTensor(kSize);
+    Tensor* parameter = FixedTensor(kSize);
     config_.set_optimizer(OptimizerConfig::SGD);
     config_.mutable_sgd()->set_momentum(0.0);
     config_.mutable_sgd()->set_decay(0.0);
     config_.mutable_sgd()->set_nesterov(false);
-    config_.set_lr_policy(OptimizerConfig::ConstLr);
+    config_.set_lr_policy(OptimizerConfig::Const);
     config_.mutable_const_lr()->set_learning_rate(0.1);
     std::string str = config_.SerializeAsString();
@@ -62,7 +59,7 @@ public:
     config_.mutable_adam()->set_beta_2(0.1);
     config_.mutable_adam()->set_epsilon(1e-3);
     config_.mutable_adam()->set_decay(0.0);
-    config_.set_lr_policy(OptimizerConfig::ConstLr);
+    config_.set_lr_policy(OptimizerConfig::Const);
     config_.mutable_const_lr()->set_learning_rate(0.1);
     std::string str = config_.SerializeAsString();
     ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
@@ -90,12 +87,13 @@ public:
   void TestCheckPoint() {
     std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
-        {OptimizerConfig::SGD, kSize}, {OptimizerConfig::Adam, kSize * 3},
+        {OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
+        {OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
     };
     for (size_t i = 0; i < opts_.size(); ++i) {
       int state_len = 0;
       std::string state = opts_[i]->SerializeState(&state_len);
-      EXPECT_EQ(state_len, expected_state_len[opts_table_[i]]);
+      EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
       opts_[i]->DeserializeState(state);
     }
   }
......
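The new expected lengths spell out what each optimizer's state holds: SGD serializes one `kSize`-float parameter tensor plus the `num_sample_passed` double; Adam serializes three such tensors (parameter, momentums, velocitys) plus the double. A back-of-envelope check, assuming `kSize` is 5 (an assumption for illustration; the fixture's actual constant is not shown in this diff):

// SGD : 5 * sizeof(float) + sizeof(double) = 20 + 8 = 28 bytes
// Adam: 5 * 3 * sizeof(float) + sizeof(double) = 60 + 8 = 68 bytes
// Caveat: this treats the protobuf encoding as the raw payload size;
// field tags and length prefixes make the actual wire size differ.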
@@ -29,11 +29,9 @@ void SGDOptimizer::Update(const Tensor *gradient) {
 const char *SGDOptimizer::SerializeState(int *state_len) {
   SGDOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
   TensorToProto(*parameter_, state.mutable_parameter());
-  TensorToProto(*momentums_, state.mutable_momentums());
+  if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
   auto str = state.SerializeAsString();
   *state_len = str.size();
   return str.c_str();
@@ -42,9 +40,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
 void SGDOptimizer::DeserializeState(const std::string &str) {
   SGDOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
-  ProtoToTensor(state.parameter(), momentums_);
 }
......
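The conditional write above ties into the checkpoint test: with `momentum` set to 0.0 (as `CreateSGD` does), no momentum tensor is serialized, which matches the SGD entry in `expected_state_len`. A hedged sketch of a deserializer symmetric with that conditional write (it assumes `SGDOptimizerState` exposes `has_momentums()`, which proto2 generates for optional fields but is an assumption here):

void SGDOptimizer::DeserializeState(const std::string &str) {
  SGDOptimizerState state;
  state.ParseFromString(str);
  num_sample_passed_ = state.num_sample_passed();
  ProtoToTensor(state.parameter(), parameter_);
  // Restore momentums only if they were written (i.e., momentum_ != 0.0).
  if (state.has_momentums()) ProtoToTensor(state.momentums(), momentums_);
}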
@@ -53,12 +53,12 @@ message AdamConfig {
   optional double decay = 44;
 }
-message ConstLr {
+message ConstLrConfig {
   // learninRate Policy
   required double learning_rate = 1 [default = 1.0];
 }
-message LinearLr {
+message LinearLrConfig {
   // learninRate Policy
   required double learning_rate = 1 [default = 1.0];
   optional double lr_decay_a = 2;
@@ -139,12 +139,12 @@ message OptimizerConfig {
   optional AdamConfig adam = 6;
   enum LrPolicy {
-    ConstLr = 0;
-    LinearLr = 1;
+    Const = 0;
+    Linear = 1;
   }
   required LrPolicy lr_policy = 11;
-  optional ConstLr const_lr = 12;
-  optional LinearLr linear_lr = 13;
+  optional ConstLrConfig const_lr = 12;
+  optional LinearLrConfig linear_lr = 13;
   // common config of optimizer
   // gradient clip when L2 exceeding value
......
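The renames read as disambiguation: with the messages now called `ConstLrConfig`/`LinearLrConfig`, the enum values `OptimizerConfig::Const` and `OptimizerConfig::Linear` no longer share names with the message types, and the C++ policy classes `ConstLr`/`LinearLr` in lr_policy.h keep the short names to themselves. An illustrative use of the renamed schema via protobuf text format (field values are examples mirroring the unit tests; the generated header name is assumed from the .proto file name):

#include <google/protobuf/text_format.h>
#include "OptimizerConfig.pb.h"  // assumed generated header

// Build an OptimizerConfig against the renamed schema, for illustration.
paddle::OptimizerConfig ExampleConfig() {
  paddle::OptimizerConfig config;
  const char *text =
      "optimizer: SGD\n"
      "lr_policy: Const\n"
      "const_lr { learning_rate: 0.1 }\n"
      "sgd { momentum: 0.0 decay: 0.0 nesterov: false }\n";
  google::protobuf::TextFormat::ParseFromString(text, &config);
  return config;
}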