// optimizer.cc — C API wrapper around paddle::optimizer::ParameterOptimizer.
#include "optimizer.h"

#include <cstdint>
#include <string>

#include "parameter_optimizer.h"
using namespace paddle;
using namespace paddle::optimizer;

template <paddle_element_type VALUE>
9 10 11 12 13 14 15
struct EnumToType {};

template <class T>
struct TypeToEnum {};

#define MATCH_ENUM_TYPE(TYPE, ENUM)                  \
  template <>                                        \
16
  struct TypeToEnum<TYPE> {                          \
17
    static paddle_element_type v() { return ENUM; }; \
18 19 20 21 22 23
    static constexpr TYPE value = ENUM;              \
  };                                                 \
  template <>                                        \
  struct EnumToType<ENUM> {                          \
    typedef TYPE Type;                               \
  }
24 25 26 27 28 29 30 31

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

// Opaque handle exposed to C callers. `impl` is owned by the handle: it is
// assigned in paddle_create_optimizer and deleted in paddle_release_optimizer.
struct paddle_optimizer {
  paddle::optimizer::ParameterOptimizer* impl;
};

/// Creates an optimizer handle from a serialized config proto.
/// @param config_proto      byte buffer holding the serialized configuration
/// @param config_proto_len  length of the buffer in bytes
/// @return a heap-allocated handle (release with paddle_release_optimizer),
///         or nullptr on invalid input.
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          int config_proto_len) {
  // Robustness fix: reject a null buffer / negative length instead of
  // reading through a bad pointer when building the std::string below.
  if (config_proto == nullptr || config_proto_len < 0) return nullptr;
  paddle_optimizer* optimizer = new paddle_optimizer;
  std::string config(config_proto, config_proto + config_proto_len);
  optimizer->impl = ParameterOptimizer::create(config);
  if (optimizer->impl == nullptr) {
    // NOTE(review): assumes create() signals a malformed config by returning
    // nullptr — confirm. Without this check the handle would leak here.
    delete optimizer;
    return nullptr;
  }
  return optimizer;
}

/// Destroys an optimizer handle created by paddle_create_optimizer.
/// @return PADDLE_SUCCESS (also for a null handle, which is a no-op).
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // Leak fix: the handle itself is heap-allocated in
    // paddle_create_optimizer and was never freed here.
    delete o;
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
50
                            const paddle_element_type data_type,
51 52
                            const void* grad_buffer,
                            int num_bytes) {
D
dzhwinter 已提交
53
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
D
dzhwinter 已提交
54 55
  auto grad_type = reinterpret_cast<const real*>(grad_buffer);
  Tensor* gradient = new Tensor(const_cast<real*>(grad_type), num_bytes);
56 57 58 59 60
  o->impl->update(gradient);
  return PADDLE_SUCCESS;
}

int paddle_optimizer_set_weights(paddle_optimizer* o,
61
                                 const paddle_element_type data_type,
62 63
                                 void* param_buffer,
                                 int num_bytes) {
D
dzhwinter 已提交
64 65
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
  Tensor* param = new Tensor(reinterpret_cast<real*>(param_buffer), num_bytes);
66 67 68 69 70 71 72 73
  o->impl->set_weight(param);
  return PADDLE_SUCCESS;
}

/// Returns the optimizer's current weight buffer as an untyped pointer.
/// NOTE(review): presumably the optimizer retains ownership of the returned
/// storage — confirm callers must not free it.
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  return (void*)o->impl->get_weight();
}