From 6ee5bc81c021e063369d1e9ba9333d534219a2cb Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 17 May 2017 19:51:36 -0400 Subject: [PATCH] use function pointer for updater dispatching --- paddle/go/pserver/optimizer.c | 32 +++++++++++++++++++++++--------- paddle/go/pserver/optimizer.go | 4 ++-- paddle/go/pserver/optimizer.h | 9 ++++----- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/paddle/go/pserver/optimizer.c b/paddle/go/pserver/optimizer.c index d83409297ba..123684970f9 100644 --- a/paddle/go/pserver/optimizer.c +++ b/paddle/go/pserver/optimizer.c @@ -2,21 +2,35 @@ #include "optimizer.h" -typedef struct { - double learning_rate; -} SGD_optimizer; +typedef int (*update_func)(void*, void *, paddle_element_type, const void*, int); -paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) { - SGD_optimizer* o = (SGD_optimizer*)malloc(sizeof(SGD_optimizer)); - o->learning_rate = learning_rate; - return (paddle_optimizer*)o; -} +typedef struct paddle_optimizer{ + update_func func; + void* optimizer; +} paddle_optimizer; void paddle_release_optimizer(paddle_optimizer* o) { free(o); } -int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type datatype, const void* gradient, int num_bytes) { +int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes) { + return o->func(o->optimizer, buffer, element_type, gradient, num_bytes); +} + +typedef struct { + double learning_rate; +} SGD_optimizer; + +int paddle_SGD_update_parameter(void* optimizer, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes) { // TODO return 0; } + +paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) { + SGD_optimizer* o = (SGD_optimizer*)malloc(sizeof(SGD_optimizer)); + o->learning_rate = learning_rate; + paddle_optimizer* container = (paddle_optimizer*)malloc(sizeof(paddle_optimizer)); + container->func = paddle_SGD_update_parameter; + container->optimizer = o; + return container; +} diff --git a/paddle/go/pserver/optimizer.go b/paddle/go/pserver/optimizer.go index aa02bed3e0f..8c6450bca0b 100644 --- a/paddle/go/pserver/optimizer.go +++ b/paddle/go/pserver/optimizer.go @@ -18,7 +18,7 @@ const ( var nullPtr = unsafe.Pointer(uintptr(0)) type optimizer struct { - opt *C.paddle_optimizer + opt *C.struct_paddle_optimizer } func newOptimizer(t optimizerType, learning_rate float64) *optimizer { @@ -46,6 +46,6 @@ func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { func (o *optimizer) Cleanup() { if unsafe.Pointer(o.opt) != nullPtr { C.paddle_release_optimizer(o.opt) - o.opt = (*C.paddle_optimizer)(nullPtr) + o.opt = (*C.struct_paddle_optimizer)(nullPtr) } } diff --git a/paddle/go/pserver/optimizer.h b/paddle/go/pserver/optimizer.h index e1750ca608e..cde8da70cca 100644 --- a/paddle/go/pserver/optimizer.h +++ b/paddle/go/pserver/optimizer.h @@ -10,10 +10,9 @@ typedef enum { PADDLE_ELEMENT_TYPE_FLOAT64 = 5, } paddle_element_type; -typedef struct paddle_optimizer paddle_optimizer; - -paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate); -void paddle_release_optimizer(paddle_optimizer* o); -int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type datatype, const void* gradient, int num_bytes); +struct paddle_optimizer; +struct paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate); +void paddle_release_optimizer(struct paddle_optimizer* o); +int paddle_update_parameter(struct paddle_optimizer* o, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes); #endif /* PADDLE_PSERVER_OPTIMIZER_H */ -- GitLab