// optimizer.cc
#include "optimizer.h"
#include <glog/logging.h>
#include <cstdlib>
#include <cstring>
#include <string>

#include "parameter_optimizer.h"

D
dzhwinter 已提交
9
using namespace paddle;
D
dzhwinter 已提交
10
using namespace paddle::optimizer;
11

D
dzhwinter 已提交
12
template <paddle_element_type VALUE>
13 14 15 16 17 18 19
struct EnumToType {};

template <class T>
struct TypeToEnum {};

#define MATCH_ENUM_TYPE(TYPE, ENUM)                  \
  template <>                                        \
20
  struct TypeToEnum<TYPE> {                          \
21
    static paddle_element_type v() { return ENUM; }; \
22 23 24 25 26 27
    static constexpr TYPE value = ENUM;              \
  };                                                 \
  template <>                                        \
  struct EnumToType<ENUM> {                          \
    typedef TYPE Type;                               \
  }
28 29 30 31 32

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
33
// TODO(zhihong): only implement below type, need to fix
34 35 36
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

D
dzhwinter 已提交
37 38
struct paddle_optimizer {
  paddle::optimizer::ParameterOptimizer* impl;
39 40 41
};

// Creates an optimizer from a serialized OptimizerConfig proto.
//
// config_proto/config_proto_len: serialized configuration bytes.
// data_type: currently ignored — the parameter buffer is assumed to hold
//            float32 (see the runtime-datatype TODO in paddle_update_parameter).
// param_buffer/num_bytes: the caller-owned parameter storage; it is wrapped,
//            not copied, so the caller must keep it alive for the optimizer's
//            lifetime.
// state/state_len: optional previously-serialized optimizer state to restore;
//            pass nullptr/0 for a fresh optimizer.
//
// Returns a heap-allocated handle; release with paddle_release_optimizer.
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          const int config_proto_len,
                                          const paddle_element_type data_type,
                                          void* param_buffer,
                                          int num_bytes,
                                          const char* state,
                                          const int state_len) {
  paddle_optimizer* optimizer = new paddle_optimizer;
  std::string config(config_proto, config_proto + config_proto_len);
  Tensor* parameter = new Tensor(reinterpret_cast<float*>(param_buffer),
                                 num_bytes / sizeof(float));
  optimizer->impl = ParameterOptimizer::Create(config, parameter);
  // Only restore state when the caller actually supplied some bytes;
  // a non-null pointer with state_len == 0 previously deserialized an
  // empty string for no reason.
  if (state != nullptr && state_len > 0) {
    std::string s(state, state + state_len);
    optimizer->impl->DeserializeState(s);
  }
  return optimizer;
}

// Destroys an optimizer handle created by paddle_create_optimizer.
// Returns PADDLE_SUCCESS unconditionally (a null handle is a no-op).
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // Bug fix: the handle itself is heap-allocated by
    // paddle_create_optimizer and was previously leaked here.
    delete o;
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
66
                            const paddle_element_type data_type,
67 68
                            const void* grad_buffer,
                            int num_bytes) {
D
dzhwinter 已提交
69
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
70
  auto grad_type = reinterpret_cast<const float*>(grad_buffer);
71 72
  Tensor* gradient =
      new Tensor(const_cast<float*>(grad_type), num_bytes / sizeof(float));
D
dzhwinter 已提交
73
  o->impl->Update(gradient);
74 75 76
  return PADDLE_SUCCESS;
}

77 78 79 80
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer) {
  int param_size = 0;
  *param_buffer = (void*)o->impl->get_weight(&param_size);
  return param_size;
81
}
82

83
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
84 85 86 87 88 89 90 91
  std::string s = o->impl->SerializeState();
  int state_len = s.size();

  if (state_len > 0) {
    *state = (char*)std::malloc(state_len);
    std::memcpy((void*)*state, (const void*)s.c_str(), state_len);
  }

D
dzhwinter 已提交
92
  return state_len;
93
}