Commit 5f9cd8c9 authored by: D dzhwinter

"rename test file name"

Parent: 26e9c4e2
@@ -26,3 +26,5 @@ add_dependencies(optimizer gen_proto_cpp)
add_simple_unittest(optimizer_test)
add_simple_unittest(optimizer_factory_test)
add_simple_unittest(Tensor_test)
add_simple_unittest(parameter_optimizer_test)
add_dependencies(parameter_optimizer_test optimizer)
#include "optimizer_factory.h" #include "optimizer_factory.h"
#include "gtest/gtest.h"
#include "parameter_optimizer.h"
typedef float TestType;
class OptimizerTest : public testing::Test {
public:
virtual void SetUp() {
paddle::OptimizerConfig config;
config.set_learning_rate(0.01);
config.set_decay(0.0);
config.set_momentum(0.0);
config.set_nesterov(false);
config.set_lr_decay_a(0.9);
config.set_lr_decay_b(0.1);
std::string config_proto = config.SerializeAsString();
    o = ParameterOptimizer<TestType>::create(config_proto);
}
virtual void TearDown() {}
private:
ParameterOptimizer<TestType>* o;
};
TEST_F(OptimizerTest, createOptimizer) {}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"
@@ -18,13 +18,13 @@ ParameterOptimizer *ParameterOptimizer::create(
  auto select_lr_policy = [=](const OptimizerConfig &config) -> BaseLr * {
    std::string s(config.lr_policy());
    if (s == "ConstLr") return new ConstLr(config.const_lr().learning_rate());
    if (s == "LinearLr")
      return new LinearLr(config.linear_lr().learning_rate(),
                          config.linear_lr().lr_decay_a(),
                          config.linear_lr().lr_decay_b());
    // default
    return nullptr;
  };
  BaseLr *lr = select_lr_policy(config);
  auto select_optimizer =
@@ -36,20 +36,20 @@ ParameterOptimizer *ParameterOptimizer::create(
                              config.sgd().nesterov(),
                              lr);
    }
    if (s == "Adadelta") {
      return new AdagradOptimizer(
          config.adagrad().epsilon(), config.adagrad().decay(), lr);
    }
    if (s == "Adagrad") {
      return new AdagradOptimizer(
          config.adagrad().epsilon(), config.adagrad().decay(), lr);
    }
    if (s == "Adam") {
      return new AdadeltaOptimizer(config.adadelta().rho(),
                                   config.adadelta().epsilon(),
                                   config.adadelta().decay(),
                                   lr);
    }
    // default
    return new SGDOptimizer(config.sgd().momentum(),
                            config.sgd().decay(),
...
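The factory above picks a learning-rate policy by name; BaseLr, ConstLr, and LinearLr come from lr_policy.h, which this diff does not show. Below is a minimal sketch of what those classes could look like, assuming a single virtual get_learning_rate(num_sample_passed) hook; the method name and the decay formula are assumptions, not confirmed by this commit.

#include <algorithm>
#include <cstdint>

// Hypothetical sketch of lr_policy.h, inferred from the constructor calls above.
class BaseLr {
public:
  explicit BaseLr(double lr) : learning_rate_(lr) {}
  virtual ~BaseLr() {}
  virtual double get_learning_rate(uint64_t num_sample_passed) = 0;

protected:
  double learning_rate_;
};

// Constant policy: always return the configured rate.
class ConstLr : public BaseLr {
public:
  explicit ConstLr(double lr) : BaseLr(lr) {}
  double get_learning_rate(uint64_t num_sample_passed) override {
    return learning_rate_;
  }
};

// Linear decay: shrink the rate by lr_decay_a per sample, floored at lr_decay_b
// (one plausible reading of the two decay parameters).
class LinearLr : public BaseLr {
public:
  LinearLr(double lr, double lr_decay_a, double lr_decay_b)
      : BaseLr(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {}
  double get_learning_rate(uint64_t num_sample_passed) override {
    return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed, lr_decay_b_);
  }

private:
  double lr_decay_a_;
  double lr_decay_b_;
};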
@@ -17,9 +17,6 @@ public:
   * @brief update hook for algorithms that need to traverse the parameter
   * more than once.
   */
// use config for pack trainig state
ParameterOptimizer(const OptimizerConfig &config) : config_(config){};
  ParameterOptimizer(BaseLr *lr) : lr_policy(lr), num_sample_passed(0) {}
  virtual ~ParameterOptimizer() { delete parameter_; };
...
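The header hunk above only shows the constructors and destructor; the tests below also rely on create(), set_weight(), get_weight(), and update(). A hedged sketch of the interface those call sites imply follows; the return and parameter types are inferred from the calls and are not confirmed by this diff.

// Hypothetical shape of the ParameterOptimizer interface used by the tests;
// exact signatures are assumptions inferred from the call sites.
class ParameterOptimizer {
public:
  explicit ParameterOptimizer(BaseLr *lr) : lr_policy(lr), num_sample_passed(0) {}
  virtual ~ParameterOptimizer() { delete parameter_; }

  // Parse a serialized OptimizerConfig string and build the matching optimizer.
  static ParameterOptimizer *create(const std::string &config_proto);

  // Hand the optimizer the parameter tensor it should update.
  virtual void set_weight(Tensor *parameter);
  // Expose the raw parameter buffer (the test casts the result to real*).
  virtual void *get_weight() const;
  // Apply one optimization step for the given gradient tensor.
  virtual void update(const Tensor *gradient) = 0;

protected:
  Tensor *parameter_;
  BaseLr *lr_policy;
  uint64_t num_sample_passed;
};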
#include "parameter_optimizer.h"
#include <cmath>
#include <tuple>
#include <vector>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "gtest/gtest.h"
#include "sgd_optimizer.h"
using namespace paddle;
using namespace paddle::optimizer;
Tensor* fill_n_Tensor(size_t size) {
real* ptr = new real[size];
Tensor* param = new Tensor(ptr, size);
Tensor& p = *param;
for (auto i = 0; i < p.size(); ++i) {
p[i] = (float)rand() / (float)RAND_MAX;
}
return param;
}
Tensor* fix_n_Tensor(size_t size) {
real* ptr = new real[size];
Tensor* param = new Tensor(ptr, size);
Tensor& p = *param;
for (auto i = 0; i < p.size(); ++i) {
p[i] = i;
}
return param;
}
class OptimizerTest : public testing::Test {
public:
// init tensor shape
const size_t size = 5;
virtual void SetUp() {
create_sgd();
create_adam();
}
virtual void TearDown() {}
void create_sgd() {
config.set_optimizer_name("SGD");
config.mutable_sgd()->set_momentum(0.0);
config.mutable_sgd()->set_decay(0.0);
config.mutable_sgd()->set_nesterov(false);
config.set_lr_policy("ConstLr");
config.mutable_const_lr()->set_learning_rate(0.1);
ParameterOptimizer* opt =
ParameterOptimizer::create(config.SerializeAsString());
opts.push_back(opt);
}
void create_adam() {
config.set_optimizer_name("Adam");
config.mutable_adam()->set_beta_1(0.9);
config.mutable_adam()->set_beta_2(0.1);
config.mutable_adam()->set_epsilon(1e-3);
config.mutable_adam()->set_decay(0.0);
config.set_lr_policy("ConstLr");
config.mutable_const_lr()->set_learning_rate(0.1);
ParameterOptimizer* opt =
ParameterOptimizer::create(config.SerializeAsString());
opts.push_back(opt);
}
void test_set_weight() {
Tensor* p = fill_n_Tensor(size);
for (size_t i = 0; i < opts.size(); ++i) {
opts[i]->set_weight(p);
}
}
void test_get_weight() {
Tensor* p = fix_n_Tensor(size);
for (size_t i = 0; i < opts.size(); ++i) {
opts[i]->set_weight(p);
}
for (size_t i = 0; i < opts.size(); ++i) {
real* newp = (real*)opts[i]->get_weight();
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(newp[j], (*p)[j]);
}
}
}
void test_update() {
Tensor* g = fix_n_Tensor(size);
for (size_t i = 0; i < opts.size(); ++i) {
opts[i]->update(g);
}
}
private:
std::vector<ParameterOptimizer*> opts;
OptimizerConfig config;
};
TEST_F(OptimizerTest, test_set_get_weight) {
test_set_weight();
test_get_weight();
}
TEST_F(OptimizerTest, test_update) { test_update(); }
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
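The two helper functions above assume a small Tensor wrapper that wraps a raw real buffer and exposes operator[] and size(). That class is not shown in this diff; here is a minimal sketch consistent with the usage, where buffer ownership and the real typedef are assumptions rather than the actual paddle implementation.

#include <cstddef>

// Minimal Tensor sketch matching how the tests use it; not the actual paddle class.
typedef float real;  // assumed; Paddle picks float or double at build time

class Tensor {
public:
  Tensor(real *data, size_t size) : data_(data), size_(size) {}
  ~Tensor() { delete[] data_; }  // this sketch takes ownership of the buffer
  real &operator[](size_t i) { return data_[i]; }
  const real &operator[](size_t i) const { return data_[i]; }
  size_t size() const { return size_; }

private:
  real *data_;
  size_t size_;
};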
#include "regularizer.h"
namespace paddle {
namespace optimizer {
template <class T>
Regularizer<T>* Regularizer<T>::create(const std::string& config_proto) {
  paddle::OptimizerConfig config;
  config.ParseFromString(config_proto);
  Regularizer<T>* r = nullptr;
  if (config.regularizer_type() == paddle::OptimizerConfig_RegularizerType_L1) {
    r = new L1Regularizer<T>(config);
  } else if (config.regularizer_type() ==
             paddle::OptimizerConfig_RegularizerType_L2) {
    r = new L2Regularizer<T>(config);
  }
  return r;
}
template class L1Regularizer<float>;
template class L1Regularizer<double>;
template class L2Regularizer<float>;
template class L2Regularizer<double>;
} // namespace optimizer
} // namespace paddle
#ifndef PADDLE_OPTIMIZER_REGULARIZER_H_
#define PADDLE_OPTIMIZER_REGULARIZER_H_
#include "OptimizerConfig.pb.h"
#include "Tensor.h"
namespace paddle {
namespace optimizer {
/**
 * @brief L1 and L2 regularizers
 */
template <class T>
class Regularizer {
public:
/**
* @brief regularizer update interface
* @param param need to update
* @return void
*/
static Regularizer *create(const std::string &config);
virtual void update(Tensor<T> &parameter) = 0;
private:
std::string regularizer_name;
OptimizerConfig config_;
};
template <class T>
class L1Regularizer : public Regularizer<T> {
public:
  void update(Tensor<T> &parameter);
};

template <class T>
class L2Regularizer : public Regularizer<T> {
public:
  void update(Tensor<T> &parameter);
};
} // namespace optimizer
} // namespace paddle
#endif
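The header declares update() for each regularizer, but no implementation appears in this diff. One common formulation applies the (sub)gradient of the penalty directly to the weights; the sketch below follows that idea, assuming Tensor<T> exposes size() and operator[] like the test-side Tensor, and using a hypothetical local decay coefficient in place of whatever the class would read from OptimizerConfig.

// Sketch only: the decay coefficient and its source in OptimizerConfig are assumptions.
template <class T>
void L1Regularizer<T>::update(Tensor<T> &parameter) {
  const T decay = T(1e-4);  // hypothetical coefficient; would come from config_ in practice
  // Subgradient of decay * |w|: push every weight toward zero.
  for (size_t i = 0; i < parameter.size(); ++i) {
    T w = parameter[i];
    parameter[i] = w - decay * (w > T(0) ? T(1) : (w < T(0) ? T(-1) : T(0)));
  }
}

template <class T>
void L2Regularizer<T>::update(Tensor<T> &parameter) {
  const T decay = T(1e-4);  // hypothetical coefficient; would come from config_ in practice
  // Gradient of (decay / 2) * ||w||^2: shrink every weight by decay * w.
  for (size_t i = 0; i < parameter.size(); ++i) {
    parameter[i] = parameter[i] - decay * parameter[i];
  }
}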
@@ -52,7 +52,12 @@ message AdamConfig {
  optional double decay = 44;
}
message ConstLr {
  // learning rate policy
  required double learning_rate = 40 [default = 1.0];
}

message LinearLr {
  // learning rate policy
  required double learning_rate = 40 [default = 1.0];
  optional double lr_decay_a = 25;
@@ -62,36 +67,26 @@ message LearningRateConfig {
message OptimizerConfig {
  // common config of optimizer
  // algorithm config, type : string
  // SGD = 1;
  // Adadelta = 2;
  // Adagrad = 3;
  // Adam = 4;
  required string optimizer_name = 1;
  // algorithm config
  enum OptimizerType {
    SGD = 1;
    Adadelta = 2;
    Adagrad = 3;
    Adam = 4;
  }
  required OptimizerType optimizer_type = 2;
  optional SGDConfig sgd = 3;
  optional AdadeltaConfig adadelta = 4;
  optional AdagradConfig adagrad = 5;
  optional AdamConfig adam = 6;

  // learning rate runtime policy config
  // lr_policy, type : string
  // ConstLr = 0;
  // LinearLr = 1;
  required string lr_policy = 11;
  optional ConstLr const_lr = 12;
  optional LinearLr linear_lr = 15;
  optional uint64 num_sample_passed = 13 [default = 0];
  // regularizer config
  enum RegularizerType {
    L1 = 1;
    L2 = 2;
    L1L2 = 3;
  }
  optional RegularizerType regularizer_type = 21;
  // common config of optimizer
  optional double clipnorm = 101;
  optional double clipvalue = 102;
...
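This change replaces the single LearningRateConfig field with separate ConstLr and LinearLr sub-messages selected by the lr_policy string. The new test above already exercises the ConstLr path; for illustration, here is a hedged snippet building a config that uses the LinearLr fields instead. The numeric values are arbitrary and the accessors follow the usual protoc-generated naming for the fields shown in this diff.

// Illustrative only: fill the revised OptimizerConfig with a LinearLr policy.
paddle::OptimizerConfig config;
config.set_optimizer_name("SGD");
config.mutable_sgd()->set_momentum(0.9);
config.mutable_sgd()->set_decay(1e-4);
config.mutable_sgd()->set_nesterov(false);
config.set_lr_policy("LinearLr");
config.mutable_linear_lr()->set_learning_rate(0.1);
config.mutable_linear_lr()->set_lr_decay_a(0.01);
config.mutable_linear_lr()->set_lr_decay_b(0.001);
ParameterOptimizer* opt =
    ParameterOptimizer::create(config.SerializeAsString());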