/* optimizer.c: C implementation of the paddle_optimizer interface (plain SGD). */
#include <stdlib.h>

#include "optimizer.h"

H
Helin Wang 已提交
5
typedef int (*update_func)(void*, void*, paddle_element_type, const void*, int);
H
Helin Wang 已提交
6
typedef void (*release_func)(void*);
7

H
Helin Wang 已提交
8
typedef struct paddle_optimizer {
H
Helin Wang 已提交
9 10
  update_func update;
  release_func release;
11 12
  void* optimizer;
} paddle_optimizer;
13 14

/*
 * Destroys an optimizer handle.
 *
 * First lets the algorithm release its private state via the `release`
 * callback, then frees the handle itself. Passing NULL is a no-op, mirroring
 * free(), so callers can release unconditionally on cleanup paths.
 */
void paddle_release_optimizer(paddle_optimizer* o) {
  if (o == NULL) {
    return;
  }
  if (o->release != NULL) {
    o->release(o->optimizer);
  }
  free(o);
}

/*
 * Applies one optimization step to a parameter buffer.
 *
 * Dispatches to the `update` callback stored in `o`, handing it the
 * optimizer's private state together with the caller's buffers. The callback's
 * return value is passed through unchanged.
 */
int paddle_update_parameter(paddle_optimizer* o,
                            void* buffer,
                            paddle_element_type element_type,
                            const void* gradient,
                            int num_bytes) {
  update_func step = o->update;
  void* state = o->optimizer;
  return step(state, buffer, element_type, gradient, num_bytes);
}

/* Private state of the plain SGD algorithm: just a fixed step size. */
typedef struct {
  double learning_rate; /* multiplier applied to every gradient element */
} SGD_optimizer;

/*
 * SGD update step: parameter[i] -= learning_rate * gradient[i] for every
 * whole float contained in num_bytes.
 *
 * optimizer     SGD_optimizer state (see paddle_create_SGD_optimizer).
 * buffer        parameter values, updated in place.
 * element_type  currently ignored: both buffers are always read as float.
 *               NOTE(review): other element types are not handled; confirm
 *               callers only pass float32 data.
 * gradient      gradient values with the same layout as buffer; not modified.
 * num_bytes     size of each buffer in bytes. A trailing partial float is
 *               ignored, and a non-positive size updates nothing.
 *
 * Returns 0.
 */
int update_SGD(void* optimizer,
               void* buffer,
               paddle_element_type element_type,
               const void* gradient,
               int num_bytes) {
  const SGD_optimizer* o = (const SGD_optimizer*)optimizer;
  float* parameter = (float*)buffer;
  const float* grad = (const float*)gradient;
  (void)element_type;

  if (num_bytes <= 0) {
    /* Guard before converting to size_t: a negative byte count would wrap
     * to a huge element count and overrun both buffers. */
    return 0;
  }

  size_t count = (size_t)num_bytes / sizeof(float);
  for (size_t i = 0; i < count; ++i) {
    parameter[i] -= o->learning_rate * grad[i];
  }
  return 0;
}

/*
 * Releases the SGD state created by paddle_create_SGD_optimizer.
 *
 * That function heap-allocates the SGD_optimizer with malloc, and
 * paddle_release_optimizer frees only the outer handle. The state must
 * therefore be freed here, or it leaks. NULL is accepted (free(NULL) is a
 * no-op).
 */
void release_SGD(void* optimizer) {
  free(optimizer);
}

paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
H
Helin Wang 已提交
52 53 54 55 56 57 58
  SGD_optimizer* impl = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
  impl->learning_rate = learning_rate;
  paddle_optimizer* opt = (paddle_optimizer*)malloc(sizeof(paddle_optimizer));
  opt->update = update_SGD;
  opt->release = release_SGD;
  opt->optimizer = impl;
  return opt;
59
}