#include "optimizer.h"
#include <string>

#include "parameter_optimizer.h"

using namespace paddle;
using namespace paddle::optimizer;

// Maps a paddle_element_type enum constant to the matching C++ type.
// Only the specializations generated by MATCH_ENUM_TYPE carry a `Type` member.
template <paddle_element_type VALUE>
struct EnumToType {};

// Maps a C++ type to the matching paddle_element_type enum constant.
// Only the specializations generated by MATCH_ENUM_TYPE carry members.
template <typename T>
struct TypeToEnum {};

// Registers the two-way association between the C++ type TYPE and the enum
// constant ENUM by specializing both TypeToEnum and EnumToType.
// The expansion deliberately ends without a semicolon; the use site adds it.
#define MATCH_ENUM_TYPE(TYPE, ENUM)                 \
  template <>                                       \
  struct TypeToEnum<TYPE> {                         \
    static paddle_element_type v() { return ENUM; } \
    static constexpr TYPE value = ENUM;             \
  };                                                \
  template <>                                       \
  struct EnumToType<ENUM> {                         \
    typedef TYPE Type;                              \
  }

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

// Handle handed out by paddle_create_optimizer. It only wraps a pointer to
// the C++ ParameterOptimizer that does the actual work.
struct paddle_optimizer {
  paddle::optimizer::ParameterOptimizer* impl;
};

// Creates an optimizer handle.
//
// config_proto / config_proto_len: byte buffer and its length; the bytes are
//     copied into a std::string and passed to ParameterOptimizer::Create.
// state / state_size: optional. When both `state` and `*state` are non-null,
//     `state_size` bytes starting at `*state` are copied into a std::string
//     and passed to DeSerializeState on the new optimizer.
//
// Returns a heap-allocated handle; release it with paddle_release_optimizer.
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          const int config_proto_len,
                                          const char** state,
                                          const int state_size) {
  paddle_optimizer* optimizer = new paddle_optimizer;
  std::string config(config_proto, config_proto + config_proto_len);
  optimizer->impl = ParameterOptimizer::Create(config);
  // Guard `*state` as well as `state`: constructing a std::string from a null
  // range with a non-zero size is undefined behavior.
  if (state != nullptr && *state != nullptr) {
    std::string s(*state, *state + state_size);
    optimizer->impl->DeSerializeState(s);
  }
  return optimizer;
}

// Destroys a handle obtained from paddle_create_optimizer. Passing nullptr is
// a no-op. Always returns PADDLE_SUCCESS.
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // The handle itself was allocated with `new` in paddle_create_optimizer
    // and its type is only defined in this file, so nobody else can free it.
    delete o;
  }
  return PADDLE_SUCCESS;
}

// Wraps `grad_buffer` in a Tensor and passes it to the optimizer's Update.
// Always returns PADDLE_SUCCESS.
int paddle_update_parameter(paddle_optimizer* o,
                            const paddle_element_type data_type,
                            const void* grad_buffer,
                            int num_bytes) {
  // TODO(zhihong): datatype not work. need to add the runtime datatype
  // NOTE(review): `data_type` is ignored and the buffer is treated as float;
  // confirm the element type against Tensor's constructor.
  auto grad_type = reinterpret_cast<const float*>(grad_buffer);
  // NOTE(review): `num_bytes` is forwarded unchanged as Tensor's size
  // argument; confirm whether Tensor expects a byte count or an element count.
  // NOTE(review): the Tensor is never deleted here; confirm whether Update
  // takes ownership of it.
  Tensor* gradient = new Tensor(const_cast<float*>(grad_type), num_bytes);
  o->impl->Update(gradient);
  return PADDLE_SUCCESS;
}

// Wraps `param_buffer` in a Tensor and hands it to the optimizer via
// set_weight. Always returns PADDLE_SUCCESS.
int paddle_optimizer_set_weights(paddle_optimizer* o,
                                 const paddle_element_type data_type,
                                 void* param_buffer,
                                 int num_bytes) {
  // TODO(zhihong): datatype not work. need to add the runtime datatype
  // NOTE(review): same caveats as paddle_update_parameter regarding the
  // assumed float element type, the meaning of `num_bytes`, and who owns the
  // Tensor allocated below.
  Tensor* param = new Tensor(reinterpret_cast<float*>(param_buffer), num_bytes);
  o->impl->set_weight(param);
  return PADDLE_SUCCESS;
}

// Returns whatever the optimizer's get_weight() yields, as an untyped pointer.
void* paddle_optimizer_get_weights(paddle_optimizer* o) {
  void* buffer = (void*)o->impl->get_weight();
  return buffer;
}

// Calls SerializeState on the optimizer. Always returns PADDLE_SUCCESS.
//
// NOTE(review): `state` is passed by value, so the assignment below only
// changes the local copy and the caller never sees the serialized state.
// Delivering it would require changing the parameter to `const char**`
// (an interface change), so the behavior is left as-is here.
int paddle_optimizer_get_state(paddle_optimizer* o, const char* state) {
  state = o->impl->SerializeState();
  return PADDLE_SUCCESS;
}