// optimizer.cc
#include "optimizer.h"

#include <cstdint>
#include <string>

#include "parameter_optimizer.h"

template <class T>
struct EnumToType {};

template <class T>
struct TypeToEnum {};

#define MATCH_ENUM_TYPE(TYPE, ENUM)                  \
  template <>                                        \
  struct TypeToEnum<ENUM> {                          \
    static paddle_element_type v() { return ENUM; }; \
    static constexpr TYPE value = ENUM;
}
;
template <>
struct EnumToType<ENUM> {
  typedef TYPE Type;
}

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

/*! \brief Handle type passed to the paddle_* functions in this file. */
struct paddle_optimizer {
  /*! \brief optimizer on the C++ side.
   *
   *  Initialized to null so a handle never carries an indeterminate pointer
   *  (deleting a null pointer is a no-op).
   *  NOTE(review): the class name is spelled "ParameterOptimzier" here;
   *  confirm it matches the declaration in parameter_optimizer.h.
   */
  paddle::optimizer::ParameterOptimzier* impl = nullptr;
};

/*!
 * \brief Creates an optimizer handle from a configuration byte buffer.
 * \param config_proto     start of the configuration bytes
 * \param config_proto_len number of bytes readable at config_proto
 * \return a heap-allocated handle that must be given back to
 *         paddle_release_optimizer(), or nullptr when the buffer arguments
 *         are unusable (negative length, or null pointer with non-zero
 *         length)
 */
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          int config_proto_len) {
  // A negative length or a null buffer with a non-zero length would make the
  // range below invalid.
  if (config_proto_len < 0 ||
      (config_proto == nullptr && config_proto_len != 0)) {
    return nullptr;
  }

  // Copy exactly config_proto_len bytes; the explicit range does not stop at
  // a zero byte.
  std::string config(config_proto, config_proto + config_proto_len);

  // The handle has to be a real object before its members are touched.
  paddle_optimizer* optimizer = new paddle_optimizer;

  // NOTE(review): assumes ParameterOptimzier::create is a static factory that
  // takes the configuration string and returns the new optimizer. Nothing
  // else in this file constructs the implementation object, so the call could
  // not work through an existing instance. Confirm against
  // parameter_optimizer.h.
  optimizer->impl = paddle::optimizer::ParameterOptimzier::create(config);
  return optimizer;
}

/*!
 * \brief Destroys an optimizer handle and the implementation it wraps.
 * \param o handle to destroy; a null handle is accepted and ignored
 * \return PADDLE_SUCCESS in every case
 *
 * The handle must not be used after this call.
 */
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // Free the handle object as well; releasing only `impl` leaks it.
    // NOTE(review): assumes the handle was allocated with `new` by
    // paddle_create_optimizer — confirm.
    delete o;
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
                            paddle_element_type data_type,
                            const void* grad_buffer,
                            int num_bytes) {
  auto type = EnumToType<data_type>::Type;
  paddle::Tensor<type> gradient(reinterpret_cast<type*>(grad_buffer),
                                num_bytes);
  o->impl->update(gradient);
  return PADDLE_SUCCESS;
}

int paddle_optimizer_set_weights(paddle_optimizer* o,
                                 paddle_element_type data_type,
                                 void* param_buffer,
                                 int num_bytes) {
  auto type = EnumToType<data_type>::Type;
  paddle::Tensor<type>* param = new paddle::Tensor<type>(
      reinterpret_cast<type*>(param_buffer), num_bytes);
  o->impl->set_weight(param);
  return PADDLE_SUCCESS;
}

/*!
 * \brief Returns whatever the optimizer's get_weight() yields, as an untyped
 *        pointer.
 * \param o optimizer handle
 * \return the result of get_weight() converted to void*, or nullptr when the
 *         handle or its implementation pointer is null
 */
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  // Same null tolerance as paddle_release_optimizer, instead of dereferencing
  // a null handle.
  if (o == nullptr || o->impl == nullptr) {
    return nullptr;
  }
  // The C-style cast is kept on purpose: the return type of get_weight() is
  // not visible from this file, so a narrower named cast could change which
  // conversion is performed.
  return (void*)o->impl->get_weight();
}