// optimizer.cc — C API bridge to paddle::optimizer::ParameterOptimizer.
#include "optimizer.h"
#include <string>

#include "parameter_optimizer.h"

using namespace paddle;
using namespace paddle::optimizer;
template <paddle_element_type VALUE>
10 11 12 13 14 15 16
struct EnumToType {};

template <class T>
struct TypeToEnum {};

#define MATCH_ENUM_TYPE(TYPE, ENUM)                  \
  template <>                                        \
17
  struct TypeToEnum<TYPE> {                          \
18
    static paddle_element_type v() { return ENUM; }; \
19 20 21 22 23 24
    static constexpr TYPE value = ENUM;              \
  };                                                 \
  template <>                                        \
  struct EnumToType<ENUM> {                          \
    typedef TYPE Type;                               \
  }
25 26 27 28 29

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
30
// TODO(zhihong): only implement below type, need to fix
31 32 33
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

// Opaque handle exposed through the C API. It owns the wrapped C++
// optimizer implementation; both are released by paddle_release_optimizer.
struct paddle_optimizer {
  paddle::optimizer::ParameterOptimizer* impl;
};

/// Creates an optimizer for the C API.
///
/// @param config_proto     serialized optimizer configuration proto
/// @param config_proto_len length of config_proto in bytes
/// @param data_type        element type of param_buffer (currently ignored;
///                         the buffer is reinterpreted as float — see the
///                         runtime-datatype TODO in paddle_update_parameter)
/// @param param_buffer     caller-owned parameter memory wrapped, not copied
/// @param num_bytes        size passed through to the Tensor wrapper
/// @param state            optional serialized optimizer state (may be null)
/// @param state_len        length of state in bytes
/// @return heap-allocated handle; release with paddle_release_optimizer
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          const int config_proto_len,
                                          const paddle_element_type data_type,
                                          void* param_buffer,
                                          int num_bytes,
                                          const char* state,
                                          const int state_len) {
  const std::string config(config_proto, config_proto + config_proto_len);
  Tensor* parameter =
      new Tensor(reinterpret_cast<float*>(param_buffer), num_bytes);

  paddle_optimizer* wrapper = new paddle_optimizer;
  wrapper->impl = ParameterOptimizer::Create(config, parameter);

  // Restore a previously-serialized state if the caller supplied one.
  if (state != nullptr) {
    const std::string checkpoint(state, state + state_len);
    wrapper->impl->DeserializeState(checkpoint);
  }
  return wrapper;
}

/// Destroys an optimizer created by paddle_create_optimizer.
/// Frees both the wrapped C++ optimizer and the wrapper struct itself;
/// passing nullptr is a no-op.
/// @return PADDLE_SUCCESS always.
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // Fix: the wrapper was allocated with `new` in paddle_create_optimizer
    // but the original code only freed `impl`, leaking the wrapper on
    // every create/release cycle.
    delete o;
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
63
                            const paddle_element_type data_type,
64 65
                            const void* grad_buffer,
                            int num_bytes) {
D
dzhwinter 已提交
66
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
67 68
  auto grad_type = reinterpret_cast<const float*>(grad_buffer);
  Tensor* gradient = new Tensor(const_cast<float*>(grad_type), num_bytes);
D
dzhwinter 已提交
69
  o->impl->Update(gradient);
70 71 72
  return PADDLE_SUCCESS;
}

/// Exposes the optimizer's current weights to the C caller.
/// On return, *param_buffer points at memory owned by the optimizer
/// (the caller must not free it); the return value is the size reported
/// by get_weight.
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer) {
  int size = 0;
  // C-style cast kept: get_weight's pointee constness is declared
  // elsewhere, and (void*) handles either case.
  void* weights = (void*)o->impl->get_weight(&size);
  *param_buffer = weights;
  return size;
}
/// Serializes the optimizer state for checkpointing.
/// On return, *state points at a buffer produced by SerializeState
/// (ownership/lifetime is defined by that method); the return value is
/// the serialized length in bytes.
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
  int serialized_len = 0;
  *state = o->impl->SerializeState(&serialized_len);
  return serialized_len;
}