#include "parameter_optimizer.h"
#include <cmath>
#include <map>
#include <string>  // std::string holds the serialized config and state
#include <vector>
#include "gtest/gtest.h"
#include "lr_policy.h"

using namespace paddle;
using namespace paddle::optimizer;

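// Fill a tensor with uniform random values in [0, 1]. Unused by the
// tests below, which need deterministic data, but kept as a helper.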
Tensor* FillTensor(size_t size) {
  Tensor* param = new Tensor(size);
  Tensor& p = *param;
  for (size_t i = 0; i < p.size(); ++i) {
    p[i] = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
  }
  return param;
}

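// Fill a tensor with the deterministic sequence 0, 1, ..., size - 1 so
// that expected values are reproducible across runs.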
Tensor* FixedTensor(size_t size) {
  Tensor* param = new Tensor(size);
  Tensor& p = *param;
  for (size_t i = 0; i < p.size(); ++i) {
    p[i] = static_cast<float>(i);
  }
  return param;
}

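// Fixture that creates one SGD and one Adam optimizer over the same
// fixed parameter tensor; every Test* helper loops over opts_ so both
// algorithms are exercised.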
class OptimizerTest : public testing::Test {
public:
  // Number of elements in each parameter tensor used by the tests.
  const size_t kSize = 5;

  virtual void SetUp() {
    CreateSGD();
    CreateAdam();
  }
  virtual void TearDown() {
    // Free the optimizers created in SetUp(). This sketch assumes
    // ParameterOptimizer has a virtual destructor and owns the
    // parameter tensor passed to Create().
    for (size_t i = 0; i < opts_.size(); ++i) {
      delete opts_[i];
    }
    opts_.clear();
  }

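  // Plain SGD: momentum, decay, and Nesterov are all disabled, with a
  // constant learning rate of 0.1, so one update is just w -= 0.1 * g.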
  void CreateSGD() {
    Tensor* parameter = FixedTensor(kSize);
    config_.set_optimizer(OptimizerConfig::SGD);
    config_.mutable_sgd()->set_momentum(0.0);
    config_.mutable_sgd()->set_decay(0.0);
    config_.mutable_sgd()->set_nesterov(false);
    config_.set_lr_policy(OptimizerConfig::Const);
    config_.mutable_const_lr()->set_learning_rate(0.1);
    std::string str = config_.SerializeAsString();
    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
    opts_.push_back(opt);
  }

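  // Adam with beta_1 = 0.9, epsilon = 1e-3, and no decay. Note that
  // beta_2 = 0.1 is far below the usual 0.999 default; the value is
  // kept as in the original test.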
  void CreateAdam() {
    Tensor* parameter = FixedTensor(kSize);
    config_.set_optimizer(OptimizerConfig::Adam);
    config_.mutable_adam()->set_beta_1(0.9);
    config_.mutable_adam()->set_beta_2(0.1);
    config_.mutable_adam()->set_epsilon(1e-3);
    config_.mutable_adam()->set_decay(0.0);
    config_.set_lr_policy(OptimizerConfig::Const);
    config_.mutable_const_lr()->set_learning_rate(0.1);
    std::string str = config_.SerializeAsString();
    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
    opts_.push_back(opt);
  }

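  // get_weight() should expose the initial parameter values: each test
  // gets a fresh fixture, so no update has run yet.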
  void TestGetWeight() {
    Tensor* p = FixedTensor(kSize);
    for (size_t i = 0; i < opts_.size(); ++i) {
      int param_size = 0;  // out-param filled in by get_weight()
      float* newp = (float*)opts_[i]->get_weight(&param_size);
      for (size_t j = 0; j < kSize; ++j) {
        EXPECT_EQ(newp[j], (*p)[j]);
      }
    }
    delete p;  // free the expected-value tensor allocated above
  }

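  // Apply one gradient step per optimizer. This only checks that
  // Update() runs; it does not verify the resulting weights.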
  void TestUpdate() {
    Tensor* g = FixedTensor(kSize);
    for (size_t i = 0; i < opts_.size(); ++i) {
      opts_[i]->Update(g);
    }
    // Assumes Update() applies the gradient immediately and does not
    // retain the pointer; free it once all optimizers have stepped.
    delete g;
  }

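  // Round-trip optimizer state through SerializeState() and
  // DeserializeState(). A stronger test could compare weights before
  // and after restoring.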
  void TestCheckPoint() {
    for (size_t i = 0; i < opts_.size(); ++i) {
      int state_len = 0;
      std::string state = opts_[i]->SerializeState(&state_len);
      opts_[i]->DeserializeState(state);
    }
  }

private:
  std::vector<ParameterOptimizer*> opts_;
  OptimizerConfig config_;
};

TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); }

TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }

TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}