optimizer.cc 3.0 KB
Newer Older
1 2 3 4
#include "optimizer.h"
#include <string>

#include "parameter_optimizer.h"
D
dzhwinter 已提交
5
using namespace paddle;
D
dzhwinter 已提交
6
using namespace paddle::optimizer;
7

D
dzhwinter 已提交
8
template <paddle_element_type VALUE>
9 10 11 12 13 14 15
struct EnumToType {};

template <class T>
struct TypeToEnum {};

#define MATCH_ENUM_TYPE(TYPE, ENUM)                  \
  template <>                                        \
16
  struct TypeToEnum<TYPE> {                          \
17
    static paddle_element_type v() { return ENUM; }; \
18 19 20 21 22 23
    static constexpr TYPE value = ENUM;              \
  };                                                 \
  template <>                                        \
  struct EnumToType<ENUM> {                          \
    typedef TYPE Type;                               \
  }
24 25 26 27 28 29 30 31

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

D
dzhwinter 已提交
32 33
struct paddle_optimizer {
  paddle::optimizer::ParameterOptimizer* impl;
34 35 36
};

/// Creates an optimizer from a serialized optimizer-config proto.
///
/// @param config_proto     serialized config bytes (not null-terminated).
/// @param config_proto_len length of config_proto in bytes.
/// @param state            optional pointer to a serialized-state buffer to
///                         restore from, or nullptr for a fresh optimizer.
/// @param state_size       length of *state in bytes (ignored when state is
///                         nullptr).
/// @return heap-allocated handle; release with paddle_release_optimizer().
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          const int config_proto_len,
                                          const char** state,
                                          const int state_size) {
  paddle_optimizer* optimizer = new paddle_optimizer;
  // Pointer-pair constructor copies exactly config_proto_len bytes; the
  // buffer need not be null-terminated.
  std::string config(config_proto, config_proto + config_proto_len);
  optimizer->impl = ParameterOptimizer::Create(config);
  if (state != nullptr) {
    std::string s(*state, *state + state_size);
    optimizer->impl->DeSerializeState(s);
  }
  return optimizer;
}

/// Destroys an optimizer handle created by paddle_create_optimizer.
///
/// @param o handle to release; nullptr is a safe no-op.
/// @return PADDLE_SUCCESS always.
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // Fix: the handle itself is heap-allocated in paddle_create_optimizer
    // and was previously leaked — only impl was freed.
    delete o;
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
56
                            const paddle_element_type data_type,
57 58
                            const void* grad_buffer,
                            int num_bytes) {
D
dzhwinter 已提交
59
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
D
dzhwinter 已提交
60 61
  auto grad_type = reinterpret_cast<const real*>(grad_buffer);
  Tensor* gradient = new Tensor(const_cast<real*>(grad_type), num_bytes);
D
dzhwinter 已提交
62
  o->impl->Update(gradient);
63 64 65 66
  return PADDLE_SUCCESS;
}

int paddle_optimizer_set_weights(paddle_optimizer* o,
67
                                 const paddle_element_type data_type,
68 69
                                 void* param_buffer,
                                 int num_bytes) {
D
dzhwinter 已提交
70 71
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
  Tensor* param = new Tensor(reinterpret_cast<real*>(param_buffer), num_bytes);
72 73 74 75 76 77 78 79
  o->impl->set_weight(param);
  return PADDLE_SUCCESS;
}

/// Returns a non-owning pointer to the optimizer's current weight buffer.
/// The buffer remains owned by the optimizer; callers must not free it.
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  // C-style cast kept deliberately: get_weight()'s return constness is not
  // visible from this file, and (void*) tolerates either.
  return (void*)o->impl->get_weight();
}

/// Intended to expose the optimizer's serialized state to the caller.
/// BUG(review): `state` is a by-value parameter, so the assignment below is
/// discarded when the function returns — the caller can never observe the
/// serialized state, and the call is effectively a no-op that returns
/// PADDLE_SUCCESS. The parameter should be an out-param (`const char**`),
/// but that requires a coordinated change to the declaration in the header,
/// so it is only flagged here.
int paddle_optimizer_get_state(paddle_optimizer* o, const char* state) {
  state = o->impl->SerializeState();
  return PADDLE_SUCCESS;
}