// parameter_optimizer_test.cpp
// Unit tests for the ParameterOptimizer implementations (committed by dzhwinter).
#include "parameter_optimizer.h"

#include <cmath>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
#include "gtest/gtest.h"
#include "sgd_optimizer.h"

using namespace paddle;
using namespace paddle::optimizer;

14
Tensor* FillTensor(size_t size) {
D
dzhwinter 已提交
15
  Tensor* param = new Tensor(size);
D
dzhwinter 已提交
16
  Tensor& p = *param;
D
dzhwinter 已提交
17
  for (size_t i = 0; i < p.size(); ++i) {
D
dzhwinter 已提交
18 19 20 21 22
    p[i] = (float)rand() / (float)RAND_MAX;
  }
  return param;
}

23
Tensor* FixedTensor(size_t size) {
D
dzhwinter 已提交
24
  Tensor* param = new Tensor(size);
D
dzhwinter 已提交
25
  Tensor& p = *param;
D
dzhwinter 已提交
26
  for (size_t i = 0; i < p.size(); ++i) {
D
dzhwinter 已提交
27 28 29 30 31 32 33 34
    p[i] = i;
  }
  return param;
}

class OptimizerTest : public testing::Test {
public:
  // init tensor shape
D
dzhwinter 已提交
35
  const size_t kSize = 5;
D
dzhwinter 已提交
36 37

  virtual void SetUp() {
38 39
    CreateSGD();
    CreateAdam();
D
dzhwinter 已提交
40 41 42
  }
  virtual void TearDown() {}

43
  void CreateSGD() {
D
dzhwinter 已提交
44 45 46 47 48 49 50
    Tensor* parameter = FillTensor(kSize);
    config_.set_optimizer(OptimizerConfig::SGD);
    config_.mutable_sgd()->set_momentum(0.0);
    config_.mutable_sgd()->set_decay(0.0);
    config_.mutable_sgd()->set_nesterov(false);
    config_.set_lr_policy(OptimizerConfig::ConstLr);
    config_.mutable_const_lr()->set_learning_rate(0.1);
D
dzhwinter 已提交
51 52

    ParameterOptimizer* opt =
D
dzhwinter 已提交
53 54 55
        ParameterOptimizer::Create(config_.SerializeAsString(), parameter);
    opts_.push_back(opt);
    opts_table_[opts_.size()] = OptimizerConfig::SGD;
D
dzhwinter 已提交
56 57
  }

58
  void CreateAdam() {
D
dzhwinter 已提交
59 60 61 62 63 64 65 66
    Tensor* parameter = FixedTensor(kSize);
    config_.set_optimizer(OptimizerConfig::Adam);
    config_.mutable_adam()->set_beta_1(0.9);
    config_.mutable_adam()->set_beta_2(0.1);
    config_.mutable_adam()->set_epsilon(1e-3);
    config_.mutable_adam()->set_decay(0.0);
    config_.set_lr_policy(OptimizerConfig::ConstLr);
    config_.mutable_const_lr()->set_learning_rate(0.1);
D
dzhwinter 已提交
67
    ParameterOptimizer* opt =
D
dzhwinter 已提交
68 69 70
        ParameterOptimizer::Create(config_.SerializeAsString(), parameter);
    opts_.push_back(opt);
    opts_table_[opts_.size()] = OptimizerConfig::Adam;
D
dzhwinter 已提交
71 72
  }

73
  void TestGetWeight() {
D
dzhwinter 已提交
74 75
    Tensor* p = FixedTensor(kSize);
    for (size_t i = 0; i < opts_.size(); ++i) {
76
      int s = 0;
D
dzhwinter 已提交
77 78
      float* newp = (float*)opts_[i]->get_weight(&s);
      for (size_t j = 0; j < kSize; ++j) {
D
dzhwinter 已提交
79 80 81 82
        EXPECT_EQ(newp[j], (*p)[j]);
      }
    }
  }
D
dzhwinter 已提交
83

84
  void TestUpdate() {
D
dzhwinter 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
    Tensor* g = FixedTensor(kSize);
    for (size_t i = 0; i < opts_.size(); ++i) {
      opts_[i]->Update(g);
    }
  }

  void TestCheckPoint() {
    std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
        {OptimizerConfig::SGD, kSize}, {OptimizerConfig::Adam, kSize * 3},
    };
    for (size_t i = 0; i < opts_.size(); ++i) {
      int state_len = 0;
      std::string state = opts_[i]->SerializeState(&state_len);
      EXPECT_EQ(state_len, expected_state_len[opts_table_[i]]);
      opts_[i]->DeserializeState(state);
D
dzhwinter 已提交
100 101 102 103
    }
  }

private:
D
dzhwinter 已提交
104 105 106
  std::vector<ParameterOptimizer*> opts_;
  std::map<int, OptimizerConfig::Optimizer> opts_table_;
  OptimizerConfig config_;
D
dzhwinter 已提交
107 108
};

D
dzhwinter 已提交
109 110
TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); }

111
TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
D
dzhwinter 已提交
112

D
dzhwinter 已提交
113 114
TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }

D
dzhwinter 已提交
115 116 117 118
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}