gradient_machine.h 2.2 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
#ifndef __PADDLE_CAPI_GRADIENT_MACHINE_H__
#define __PADDLE_CAPI_GRADIENT_MACHINE_H__
#include "arguments.h"
#include "config.h"
#include "error.h"

#ifdef __cplusplus
extern "C" {
#endif
/**
 * @brief Opaque handle to a gradient machine, i.e. a neural network.
 *
 * The handle is an untyped pointer. It is filled in by the
 * paddle_gradient_machine_create_* functions (through their output
 * parameter) and is passed to paddle_gradient_machine_destroy when done.
 */
typedef void *paddle_gradient_machine;

/**
 * @brief Create a gradient machine to be used for model inference.
 * @param [out] machine receives the newly created gradient machine.
 * @param [in] modelConfigProtobuf buffer holding the model configuration
 *             protobuf (NOTE(review): presumably serialized bytes -- confirm).
 * @param [in] size size of the modelConfigProtobuf buffer.
 * @return paddle_error status code.
 */
PD_API paddle_error
paddle_gradient_machine_create_for_inference(paddle_gradient_machine* machine,
                                             void* modelConfigProtobuf,
                                             int size);

/**
 * @brief Load the parameters of a gradient machine from disk.
 * @param machine gradient machine that receives the loaded parameters.
 * @param path path of the local directory to read the parameters from.
 * @return paddle_error status code.
 */
PD_API paddle_error paddle_gradient_machine_load_parameter_from_disk(
    paddle_gradient_machine machine,
    const char* path);

/**
 * @brief Run a forward pass of a gradient machine.
 * @param machine gradient machine to run.
 * @param inArgs handle to the input arguments of the pass.
 * @param outArgs handle to the output arguments of the pass.
 * @param isTrain true to run the pass in training mode, false otherwise.
 * @return paddle_error status code.
 *
 * NOTE(review): `bool` requires <stdbool.h> when compiled as C; this header
 * relies on one of its includes providing it -- confirm.
 */
PD_API paddle_error paddle_gradient_machine_forward(
    paddle_gradient_machine machine,
    paddle_arguments inArgs,
    paddle_arguments outArgs,
    bool isTrain);

/**
 * @brief Create a gradient machine whose parameters are shared with another,
 *        already existing gradient machine.
 * @param [in] origin gradient machine whose parameters are shared.
 * @param [in] modelConfigProtobuf buffer holding the model configuration
 *             protobuf.
 * @param [in] size size of the modelConfigProtobuf buffer.
 * @param [out] slave receives the newly created (slave) gradient machine.
 * @return paddle_error status code.
 */
PD_API paddle_error paddle_gradient_machine_create_shared_param(
    paddle_gradient_machine origin,
    void* modelConfigProtobuf,
    int size,
    paddle_gradient_machine* slave);

/**
 * @brief Destroy a gradient machine.
 * @param machine gradient machine to destroy.
 * @return paddle_error status code.
 */
PD_API paddle_error paddle_gradient_machine_destroy(
    paddle_gradient_machine machine);

#ifdef __cplusplus
}
#endif
#endif