#pragma once

/* NOTE(review): the include targets were not visible in the reviewed text.
 * <stdint.h> is required for int32_t below; <stdbool.h> is presumed to be the
 * other header -- confirm against the original file. */
#include <stdbool.h>
#include <stdint.h>

/**
 * @brief Optimizer library, independent of the other modules.
 *
 * It is intended to be used in two cases:
 *   Case A: the gradient is optimized locally on the trainer.
 *   Case B: the gradient is optimized on the parameter server.
 */

#ifdef __cplusplus
extern "C" {
#endif

/**
 * @brief Element type tag passed alongside gradient and parameter buffers.
 */
typedef enum {
  PADDLE_ELEMENT_TYPE_INT32 = 0,
  PADDLE_ELEMENT_TYPE_UINT32 = 1,
  PADDLE_ELEMENT_TYPE_INT64 = 2,
  PADDLE_ELEMENT_TYPE_UINT64 = 3,
  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;

/**
 * @brief Execution status codes.
 *
 * Declared `static` so that every translation unit including this header gets
 * its own private copy. Without it, C gives these objects external linkage and
 * including the header from more than one .c file produces duplicate-symbol
 * link errors (C++ already treats namespace-scope const objects as internal).
 */
static const int32_t PADDLE_SUCCESS = 0;
static const int32_t PADDLE_ERROR = -1;

/** Opaque optimizer handle; its layout is not exposed by this header. */
typedef struct paddle_optimizer paddle_optimizer;

/**
 * The functions below are meant to be called in this order:
 *   1. create the optimizer with a config
 *   2. set weights
 *   3. update_parameter
 *   4. get_weights
 *   5. release the optimizer
 */

/**
 * @brief Create an optimizer from a protobuf configuration.
 * @param config_proto optimizer protobuf, see OptimizerConfig.proto for
 *        details
 * @param config_proto_len length of config_proto (presumably in bytes --
 *        TODO confirm)
 * @param data_type element type (presumably of param_buffer -- TODO confirm)
 * @param param_buffer parameter buffer handed to the optimizer
 * @param num_bytes presumably the size of param_buffer in bytes -- TODO
 *        confirm
 * @param state presumably a training state previously obtained from
 *        paddle_optimizer_get_state -- TODO confirm
 * @param state_len length of state
 * @return the optimizer instance
 */
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          const int config_proto_len,
                                          const paddle_element_type data_type,
                                          void* param_buffer,
                                          int num_bytes,
                                          const char* state,
                                          const int state_len);

/**
 * @brief Release an optimizer.
 * @param o optimizer to release
 * @return execution status
 */
int paddle_release_optimizer(paddle_optimizer* o);

/**
 * @brief Update the parameters using a gradient.
 * @param o optimizer instance
 * @param data_type data type of the gradient and the parameter
 * @param gradient gradient, calculated by the caller of the optimizer.
 *        TODO(zhihong): just pass the loss to reduce communication overhead;
 *        see the Project Adam (MS '14) paper for details.
 * @param num_bytes gradient size
 * @return execution status
 */
int paddle_update_parameter(paddle_optimizer* o,
                            const paddle_element_type data_type,
                            const void* gradient,
                            int num_bytes);

/**
 * @brief Get the optimizer's parameter buffer.
 * @param o optimizer instance
 * @param param_buffer receives the parameter buffer (ownership of the
 *        returned pointer is not specified here -- TODO confirm)
 * @return content length
 */
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);

/**
 * @brief Get the optimizer's training state, for saving.
 * @param o optimizer instance
 * @param state receives the serialized training state
 * @return length of the state buffer
 */
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);

#ifdef __cplusplus
}
#endif