/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdbool.h>
#include <stdint.h>

/**
 * @brief optimizer library, independent of other modules,
 * which is used in:
 * Case A, the gradient is optimized locally on the trainer.
 *
 * Case B, the gradient is optimized on the parameter server.
 */

#ifdef __cplusplus
extern "C" {
#endif

typedef enum {
  PADDLE_ELEMENT_TYPE_INT32 = 0,
  PADDLE_ELEMENT_TYPE_UINT32 = 1,
  PADDLE_ELEMENT_TYPE_INT64 = 2,
  PADDLE_ELEMENT_TYPE_UINT64 = 3,
  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
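
/*
 * Example (illustrative only): for PADDLE_ELEMENT_TYPE_FLOAT32, a buffer of
 * `n` elements is passed to the functions below as a raw pointer plus a byte
 * count, e.g.
 *
 *   float params[256];
 *   int num_bytes = 256 * sizeof(float);  // passed alongside the buffer
 */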

/**
 * @brief execution status code
 */
const int32_t PADDLE_SUCCESS = 0;
const int32_t PADDLE_ERROR = -1;

typedef struct paddle_optimizer paddle_optimizer;
/**
 * This group of interfaces is called in the following order
 * (a usage sketch follows below):
 * 1. create optimizer with config
 * 2. set weights
 * 3. update_parameter
 * 4. get_weights
 * 5. release optimizer
 */
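
/*
 * Minimal usage sketch (illustrative only; `config_bytes`, `config_len`,
 * `params`, `grads`, and `num_bytes` are placeholder values supplied by the
 * caller, and a NULL/0 state is assumed to mean "start from scratch"):
 *
 *   paddle_optimizer* o = paddle_create_optimizer(
 *       config_bytes, config_len, PADDLE_ELEMENT_TYPE_FLOAT32,
 *       params, num_bytes, NULL, 0);
 *   paddle_update_parameter(o, PADDLE_ELEMENT_TYPE_FLOAT32, grads, num_bytes);
 *   void* weights = NULL;
 *   paddle_optimizer_get_weights(o, &weights);
 *   paddle_release_optimizer(o);
 */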

/**
 *  @brief create optimizer with proto_config
 *  @param config_proto, optimizer protobuf; see OptimizerConfig.proto for details
 *  @return optimizer instance
 */
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          const int config_proto_len,
                                          const paddle_element_type data_type,
                                          void* param_buffer,
                                          int num_bytes,
                                          const char* state,
                                          const int state_len);

/**
 *  @brief release optimizer
 *  @param o, the optimizer instance
 *  @return exec status
 */
int paddle_release_optimizer(paddle_optimizer* o);

/**
 *  @brief update parameters with a gradient
 *  @param data_type, datatype of gradient and parameter
 *  @param gradient, calculated by the optimizer caller.
 *       TODO(zhihong): just pass the loss to reduce communication overhead.
 *                     See the Project Adam paper (Microsoft, 2014) for details.
 *  @param num_bytes, gradient size
 *  @return exec status
 */
int paddle_update_parameter(paddle_optimizer* o,
                            const paddle_element_type data_type,
                            const void* gradient,
                            int num_bytes);

/**
 *  @brief get the parameter buffer from the optimizer
 *  @param param_buffer, receives the initialized parameter buffer
 *  @return content length
 */
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);

/**
 *  @brief save the optimizer training state
 *  @param state, receives the serialized training state (SerializeState)
 *  @return state buffer length
 */
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
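
/*
 * Checkpointing sketch (illustrative; it assumes the buffer returned here can
 * be fed back through the `state`/`state_len` parameters of
 * paddle_create_optimizer to resume training, as the create signature
 * suggests):
 *
 *   const char* state = NULL;
 *   int state_len = paddle_optimizer_get_state(o, &state);
 *   // persist state_len bytes of state, then later:
 *   paddle_optimizer* restored = paddle_create_optimizer(
 *       config_bytes, config_len, PADDLE_ELEMENT_TYPE_FLOAT32,
 *       params, num_bytes, state, state_len);
 */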

#ifdef __cplusplus
}
#endif