// optimizer.cc — C API wrapper around paddle::optimizer::ParameterOptimizer.
#include "optimizer.h"

#include <cstdint>
#include <string>

#include "parameter_optimizer.h"

using namespace paddle::optimizer;
// Compile-time two-way mapping between C++ element types and the C API's
// paddle_element_type enum tags:
//   TypeToEnum<T>::v()   -> the enum tag for element type T
//   EnumToType<E>::Type  -> the C++ type for enum tag E
// Specializations are stamped out by MATCH_ENUM_TYPE below.
template <paddle_element_type VALUE>
struct EnumToType {};

template <class T>
struct TypeToEnum {};

// NOTE(review): `value` is declared with type TYPE but initialized from the
// enum tag, so for float/double it stores the tag's numeric value as a
// floating-point constant — confirm that is intended rather than
// `static constexpr paddle_element_type value = ENUM;`.
#define MATCH_ENUM_TYPE(TYPE, ENUM)                 \
  template <>                                       \
  struct TypeToEnum<TYPE> {                         \
    static paddle_element_type v() { return ENUM; } \
    static constexpr TYPE value = ENUM;             \
  };                                                \
  template <>                                       \
  struct EnumToType<ENUM> {                         \
    typedef TYPE Type;                              \
  }
23 24 25 26 27 28 29 30

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

D
dzhwinter 已提交
31 32
struct paddle_optimizer {
  paddle::optimizer::ParameterOptimizer* impl;
33 34 35 36
};

// Creates an optimizer from a serialized optimizer-config protobuf blob.
// config_proto/config_proto_len describe the raw bytes of the config.
// Returns a heap-allocated handle; callers must free it with
// paddle_release_optimizer.
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          int config_proto_len) {
  paddle_optimizer* optimizer = new paddle_optimizer;
  // Copy the raw proto bytes into a string for the factory.
  std::string config(config_proto, config_proto + config_proto_len);
  optimizer->impl = ParameterOptimizer::create(config);
  return optimizer;
}

// Destroys a handle created by paddle_create_optimizer. Safe on nullptr.
// Fix: the original deleted only o->impl and leaked the handle struct
// itself, which paddle_create_optimizer allocates with `new`.
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    delete o;  // the handle is heap-allocated in paddle_create_optimizer
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
49
                            const paddle_element_type data_type,
50 51
                            const void* grad_buffer,
                            int num_bytes) {
D
dzhwinter 已提交
52 53 54
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
  auto grad = reinterpret_cast<const real*>(grad_buffer);
  Tensor gradient(const_cast<real*>(grad), num_bytes);
55 56 57 58 59
  o->impl->update(gradient);
  return PADDLE_SUCCESS;
}

int paddle_optimizer_set_weights(paddle_optimizer* o,
60
                                 const paddle_element_type data_type,
61 62
                                 void* param_buffer,
                                 int num_bytes) {
D
dzhwinter 已提交
63 64
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
  Tensor* param = new Tensor(reinterpret_cast<real*>(param_buffer), num_bytes);
65 66 67 68 69 70 71 72
  o->impl->set_weight(param);
  return PADDLE_SUCCESS;
}

// Returns a raw pointer to the optimizer's current weight buffer.
// The buffer stays owned by the optimizer; callers must not free it.
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  // C-style cast kept as-is: get_weight()'s exact return type is declared
  // elsewhere and may be const-qualified, so a named cast can't be chosen
  // safely from this view.
  void* buffer = (void*)o->impl->get_weight();
  return buffer;
}