diff --git a/paddle/go/pserver/optimizer.c b/paddle/go/pserver/optimizer.c index 123684970f943912f1c2dbfdde9fe71ce6c3a9e8..36a612a56f4a5673a93843b30366187bea5ac811 100644 --- a/paddle/go/pserver/optimizer.c +++ b/paddle/go/pserver/optimizer.c @@ -3,34 +3,44 @@ #include "optimizer.h" typedef int (*update_func)(void*, void *, paddle_element_type, const void*, int); +typedef void (*release_func)(void*); typedef struct paddle_optimizer{ - update_func func; + update_func update; + release_func release; void* optimizer; } paddle_optimizer; void paddle_release_optimizer(paddle_optimizer* o) { + o->release(o->optimizer); free(o); } int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes) { - return o->func(o->optimizer, buffer, element_type, gradient, num_bytes); + return o->update(o->optimizer, buffer, element_type, gradient, num_bytes); } typedef struct { double learning_rate; } SGD_optimizer; -int paddle_SGD_update_parameter(void* optimizer, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes) { +int update_SGD(void* optimizer, void *buffer, paddle_element_type element_type, const void* gradient, int num_bytes) { + SGD_optimizer* o = (SGD_optimizer*)optimizer; // TODO return 0; } +void release_SGD(void *optimizer) { + SGD_optimizer* o = (SGD_optimizer*)optimizer; + // nothing allocated on heap +} + paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) { SGD_optimizer* o = (SGD_optimizer*)malloc(sizeof(SGD_optimizer)); o->learning_rate = learning_rate; paddle_optimizer* container = (paddle_optimizer*)malloc(sizeof(paddle_optimizer)); - container->func = paddle_SGD_update_parameter; + container->update = update_SGD; + container->release = release_SGD; container->optimizer = o; return container; } diff --git a/paddle/go/pserver/optimizer_test.go b/paddle/go/pserver/optimizer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..64d6d092aa1864fbca012214ced5e03e157d4a4c --- /dev/null +++ b/paddle/go/pserver/optimizer_test.go @@ -0,0 +1,8 @@ +package pserver + +import "testing" + +func TestSGDCreateRelease(t *testing.T) { + o := newOptimizer(sgd, 1) + o.Cleanup() +}