parameter_optimizer_test.cpp 3.2 KB
Newer Older
D
dzhwinter 已提交
1 2
#include "parameter_optimizer.h"
#include <cmath>
D
dzhwinter 已提交
3
#include <map>
D
dzhwinter 已提交
4 5
#include <vector>
#include "gtest/gtest.h"
D
dzhwinter 已提交
6
#include "lr_policy.h"
D
dzhwinter 已提交
7

D
dzhwinter 已提交
8 9 10
using namespace paddle;
using namespace paddle::optimizer;

11
// Heap-allocates a Tensor of `size` elements and fills every element with
// rand() / RAND_MAX, i.e. a pseudo-random value in [0, 1]. The returned
// tensor is not freed here; the caller is responsible for it.
Tensor* FillTensor(size_t size) {
  Tensor* tensor = new Tensor(size);
  const float max_value = static_cast<float>(RAND_MAX);
  for (size_t idx = 0; idx < tensor->size(); ++idx) {
    (*tensor)[idx] = static_cast<float>(rand()) / max_value;
  }
  return tensor;
}

20
// Heap-allocates a Tensor of `size` elements where element k holds the value
// k (0, 1, ..., size - 1), giving tests a deterministic parameter/gradient.
// The returned tensor is not freed here; the caller is responsible for it.
Tensor* FixedTensor(size_t size) {
  Tensor* tensor = new Tensor(size);
  Tensor& values = *tensor;
  size_t idx = 0;
  while (idx < values.size()) {
    values[idx] = idx;
    ++idx;
  }
  return tensor;
}

class OptimizerTest : public testing::Test {
public:
  // init tensor shape
D
dzhwinter 已提交
32
  const size_t kSize = 5;
D
dzhwinter 已提交
33 34

  virtual void SetUp() {
35 36
    CreateSGD();
    CreateAdam();
D
dzhwinter 已提交
37 38 39
  }
  virtual void TearDown() {}

40
  void CreateSGD() {
D
dzhwinter 已提交
41
    Tensor* parameter = FixedTensor(kSize);
D
dzhwinter 已提交
42 43 44 45
    config_.set_optimizer(OptimizerConfig::SGD);
    config_.mutable_sgd()->set_momentum(0.0);
    config_.mutable_sgd()->set_decay(0.0);
    config_.mutable_sgd()->set_nesterov(false);
D
dzhwinter 已提交
46
    config_.set_lr_policy(OptimizerConfig::Const);
D
dzhwinter 已提交
47
    config_.mutable_const_lr()->set_learning_rate(0.1);
D
dzhwinter 已提交
48

D
dzhwinter 已提交
49 50
    std::string str = config_.SerializeAsString();
    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
D
dzhwinter 已提交
51 52
    opts_.push_back(opt);
    opts_table_[opts_.size()] = OptimizerConfig::SGD;
D
dzhwinter 已提交
53 54
  }

55
  void CreateAdam() {
D
dzhwinter 已提交
56 57 58 59 60 61
    Tensor* parameter = FixedTensor(kSize);
    config_.set_optimizer(OptimizerConfig::Adam);
    config_.mutable_adam()->set_beta_1(0.9);
    config_.mutable_adam()->set_beta_2(0.1);
    config_.mutable_adam()->set_epsilon(1e-3);
    config_.mutable_adam()->set_decay(0.0);
D
dzhwinter 已提交
62
    config_.set_lr_policy(OptimizerConfig::Const);
D
dzhwinter 已提交
63
    config_.mutable_const_lr()->set_learning_rate(0.1);
D
dzhwinter 已提交
64 65
    std::string str = config_.SerializeAsString();
    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
D
dzhwinter 已提交
66 67
    opts_.push_back(opt);
    opts_table_[opts_.size()] = OptimizerConfig::Adam;
D
dzhwinter 已提交
68 69
  }

70
  void TestGetWeight() {
D
dzhwinter 已提交
71 72
    Tensor* p = FixedTensor(kSize);
    for (size_t i = 0; i < opts_.size(); ++i) {
73
      int s = 0;
D
dzhwinter 已提交
74 75
      float* newp = (float*)opts_[i]->get_weight(&s);
      for (size_t j = 0; j < kSize; ++j) {
D
dzhwinter 已提交
76 77 78 79
        EXPECT_EQ(newp[j], (*p)[j]);
      }
    }
  }
D
dzhwinter 已提交
80

81
  void TestUpdate() {
D
dzhwinter 已提交
82 83 84 85 86 87 88 89
    Tensor* g = FixedTensor(kSize);
    for (size_t i = 0; i < opts_.size(); ++i) {
      opts_[i]->Update(g);
    }
  }

  void TestCheckPoint() {
    std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
D
dzhwinter 已提交
90 91
        {OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
        {OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
D
dzhwinter 已提交
92 93 94 95
    };
    for (size_t i = 0; i < opts_.size(); ++i) {
      int state_len = 0;
      std::string state = opts_[i]->SerializeState(&state_len);
D
dzhwinter 已提交
96
      EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
D
dzhwinter 已提交
97
      opts_[i]->DeserializeState(state);
D
dzhwinter 已提交
98 99 100 101
    }
  }

private:
D
dzhwinter 已提交
102 103 104
  std::vector<ParameterOptimizer*> opts_;
  std::map<int, OptimizerConfig::Optimizer> opts_table_;
  OptimizerConfig config_;
D
dzhwinter 已提交
105 106
};

D
dzhwinter 已提交
107 108
// Each registered test case simply delegates to the fixture method of the
// same name; gtest runs SetUp()/TearDown() around each of them.
TEST_F(OptimizerTest, TestGetWeight) {
  TestGetWeight();
}

TEST_F(OptimizerTest, TestUpdate) {
  TestUpdate();
}

TEST_F(OptimizerTest, TestCheckPoint) {
  TestCheckPoint();
}

D
dzhwinter 已提交
113 114 115 116
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}