提交 6ee5bc81 编写于 作者: H Helin Wang

use function pointer for updater dispatching

上级 55217c96
......@@ -2,21 +2,35 @@
#include "optimizer.h"
typedef struct {
double learning_rate;
} SGD_optimizer;
typedef int (*update_func)(void*, void *, paddle_element_type, const void*, int);
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
SGD_optimizer* o = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
o->learning_rate = learning_rate;
return (paddle_optimizer*)o;
}
typedef struct paddle_optimizer{
update_func func;
void* optimizer;
} paddle_optimizer;
void paddle_release_optimizer(paddle_optimizer* o) {
free(o);
}
int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type datatype, const void* gradient, int num_bytes) {
int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes) {
return o->func(o->optimizer, buffer, element_type, gradient, num_bytes);
}
typedef struct {
double learning_rate;
} SGD_optimizer;
int paddle_SGD_update_parameter(void* optimizer, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes) {
// TODO
return 0;
}
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
SGD_optimizer* o = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
o->learning_rate = learning_rate;
paddle_optimizer* container = (paddle_optimizer*)malloc(sizeof(paddle_optimizer));
container->func = paddle_SGD_update_parameter;
container->optimizer = o;
return container;
}
......@@ -18,7 +18,7 @@ const (
var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct {
opt *C.paddle_optimizer
opt *C.struct_paddle_optimizer
}
func newOptimizer(t optimizerType, learning_rate float64) *optimizer {
......@@ -46,6 +46,6 @@ func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
func (o *optimizer) Cleanup() {
if unsafe.Pointer(o.opt) != nullPtr {
C.paddle_release_optimizer(o.opt)
o.opt = (*C.paddle_optimizer)(nullPtr)
o.opt = (*C.struct_paddle_optimizer)(nullPtr)
}
}
......@@ -10,10 +10,9 @@ typedef enum {
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
typedef struct paddle_optimizer paddle_optimizer;
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate);
void paddle_release_optimizer(paddle_optimizer* o);
int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type datatype, const void* gradient, int num_bytes);
struct paddle_optimizer;
struct paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate);
void paddle_release_optimizer(struct paddle_optimizer* o);
int paddle_update_parameter(struct paddle_optimizer* o, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes);
#endif /* PADDLE_PSERVER_OPTIMIZER_H */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册