/* optimizer.h - C interface of the optimizer library. */
#pragma once

#include <stdbool.h>
#include <stdint.h>

/**
 * @brief Optimizer library, independent of the other modules. It is used in
 * two cases:
 *
 * Case A: the gradient is optimized locally on the trainer.
 *
 * Case B: the gradient is optimized on the parameter server.
 */

#ifdef __cplusplus
extern "C" {
#endif
/**
 * @brief Element type of the parameter and gradient buffers; passed as
 * `data_type` to paddle_create_optimizer and paddle_update_parameter.
 *
 * The enumerators count up from 0 (INT32 = 0 ... FLOAT64 = 5). Keep that
 * numbering stable: add new element types at the end only.
 */
typedef enum {
  PADDLE_ELEMENT_TYPE_INT32 = 0,
  PADDLE_ELEMENT_TYPE_UINT32,
  PADDLE_ELEMENT_TYPE_INT64,
  PADDLE_ELEMENT_TYPE_UINT64,
  PADDLE_ELEMENT_TYPE_FLOAT32,
  PADDLE_ELEMENT_TYPE_FLOAT64
} paddle_element_type;

/**
 * @brief Execution status codes, returned as `int` by
 * paddle_release_optimizer and paddle_update_parameter.
 *
 * These are enumeration constants rather than `const int32_t` objects.
 * In C, a file-scope `const` object has external linkage, so defining one in
 * a header causes a duplicate-symbol link error once two translation units
 * include this file. Enumerators define no object, and they are also valid
 * constant expressions (e.g. `case` labels).
 */
enum {
  PADDLE_SUCCESS = 0,
  PADDLE_ERROR = -1
};

/**
 * @brief Optimizer handle. The struct is only forward-declared in this
 * header, so callers hold it through the pointer returned by
 * paddle_create_optimizer.
 */
struct paddle_optimizer;
typedef struct paddle_optimizer paddle_optimizer;
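/**
 * The functions below are meant to be called in this order:
 * 1. create the optimizer from its config (paddle_create_optimizer)
 * 2. set the weights (NOTE(review): no separate setter is declared here;
 *    presumably done through the param_buffer argument of
 *    paddle_create_optimizer - confirm)
 * 3. update the parameters (paddle_update_parameter)
 * 4. get the weights (paddle_optimizer_get_weights)
 * 5. release the optimizer (paddle_release_optimizer)
 */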

/**
 *  @brief create optimizer with proto_config
 *  @param config_proto, optimizer protobuf, see OptimizerConfig.proto in detail
 *  @return return optimizer instance
 */
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
49
                                          const int config_proto_len,
50 51 52 53 54
                                          const paddle_element_type data_type,
                                          void* param_buffer,
                                          int num_bytes,
                                          const char* state,
                                          const int state_len);
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72

/**
 *  @brief release an optimizer created by paddle_create_optimizer
 *  @param o  optimizer instance to release
 *  @return execution status code (PADDLE_SUCCESS / PADDLE_ERROR)
 */
int paddle_release_optimizer(paddle_optimizer* o);

/**
 *  @brief optimizer instance
 *  @param datatype of gradient and parameter
 *  @param gradient, calculate by optimzizer caller.
 *       TODO(zhihong): just pass loss to reduce communicate overhead.
 *                     Project Adam Ms'14 paper for detail
 *  @param num_bytes, gradient size
 *  @return return exec status
 */
int paddle_update_parameter(paddle_optimizer* o,
73
                            const paddle_element_type data_type,
74 75 76 77 78 79
                            const void* gradient,
                            int num_bytes);

/**
 *  @brief optimizer instance
 *  @param param_buffer, initilized parameter buffer
80
 *  @return return content length
81
 */
82
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);
83 84

/**
85 86 87
 *  @brief optimzizer instance
 *  @param training state for receive SerializeState
 *  @return return state_buffer length
88
 */
89
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
90

#ifdef __cplusplus
}
#endif