// optimizer.cc
#include "optimizer.h"
#include <string>

#include "parameter_optimizer.h"

6
template <paddle_element_type T>
7 8 9 10 11 12 13
struct EnumToType {};

template <class T>
struct TypeToEnum {};

#define MATCH_ENUM_TYPE(TYPE, ENUM)                  \
  template <>                                        \
14
  struct TypeToEnum<TYPE> {                          \
15
    static paddle_element_type v() { return ENUM; }; \
16 17 18 19 20 21
    static constexpr TYPE value = ENUM;              \
  };                                                 \
  template <>                                        \
  struct EnumToType<ENUM> {                          \
    typedef TYPE Type;                               \
  }
22 23 24 25 26 27 28

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);
29
 struct paddle_optimizer {
30 31
  /*! \brief optmizer in C++ side */

32
  paddle::optimizer::ParameterOptimizerBase* impl;
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
};

/**
 * @brief Creates an optimizer handle from a serialized configuration.
 * @param config_proto     serialized configuration bytes; copied by length,
 *                         so it need not be NUL-terminated. May be nullptr
 *                         only when config_proto_len is 0.
 * @param config_proto_len number of bytes in config_proto; must be >= 0.
 * @return a heap-allocated handle to be released with
 *         paddle_release_optimizer(), or nullptr if the arguments are invalid.
 */
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          int config_proto_len) {
  // A negative length (or a null buffer with a non-zero length) would make
  // the iterator range below invalid.
  if (config_proto_len < 0 ||
      (config_proto == nullptr && config_proto_len > 0)) {
    return nullptr;
  }
  // Copy exactly config_proto_len bytes into an owning string.
  std::string config(config_proto, config_proto + config_proto_len);
  // The handle must be allocated before its members are written. Previously
  // an uninitialized pointer was dereferenced here.
  paddle_optimizer* optimizer = new paddle_optimizer;
  // NOTE(review): assumes ParameterOptimizerBase::create is a static factory
  // that takes the serialized config and returns an owned pointer; confirm
  // against parameter_optimizer.h.
  optimizer->impl = paddle::optimizer::ParameterOptimizerBase::create(config);
  return optimizer;
}

/**
 * @brief Destroys an optimizer handle and the C++ optimizer it owns.
 * @param o handle to destroy; nullptr is accepted and ignored.
 * @return PADDLE_SUCCESS.
 */
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // The handle itself must be freed as well; deleting only impl leaked it.
    // Assumes the handle was allocated with `new` by paddle_create_optimizer.
    delete o;
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
49
                            const paddle_element_type data_type,
50 51 52 53 54 55 56 57 58 59
                            const void* grad_buffer,
                            int num_bytes) {
  auto type = EnumToType<data_type>::Type;
  paddle::Tensor<type> gradient(reinterpret_cast<type*>(grad_buffer),
                                num_bytes);
  o->impl->update(gradient);
  return PADDLE_SUCCESS;
}

int paddle_optimizer_set_weights(paddle_optimizer* o,
60
                                 const paddle_element_type data_type,
61 62 63 64 65 66 67 68 69 70 71 72 73
                                 void* param_buffer,
                                 int num_bytes) {
  auto type = EnumToType<data_type>::Type;
  paddle::Tensor<type>* param = new paddle::Tensor<type>(
      reinterpret_cast<type*>(param_buffer), num_bytes);
  o->impl->set_weight(param);
  return PADDLE_SUCCESS;
}

/**
 * @brief Returns the optimizer's current weights as an untyped pointer.
 * @param o optimizer handle; nullptr yields nullptr.
 * @return whatever impl->get_weight() returns, cast to void*, or nullptr if
 *         the handle (or its impl) is null.
 *
 * NOTE(review): the pointee type of get_weight() is not visible here. If it
 * returns a Tensor* rather than the raw data buffer, C callers will receive
 * the Tensor object, not the weights; confirm.
 */
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  if (o == nullptr || o->impl == nullptr) return nullptr;
  // Named casts instead of a C-style cast: go through const void* so both
  // const and non-const object pointers convert.
  return const_cast<void*>(static_cast<const void*>(o->impl->get_weight()));
}