// test_cclient.c
#include <stdio.h>
#include <stdlib.h>

#include "libpaddle_pserver_cclient.h"

typedef float real;

// Reports a test failure on stdout and terminates the process with
// exit status -1.
//
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
void fail() {
  fputs("test failed.\n", stdout);
  exit(-1);
}

// Prints a parameter (or gradient) to stdout: its name, content length,
// element type, and then its content.
//
// The content is always printed as an array of `real`, whatever
// element_type says; content_len is divided by sizeof(real) to get the
// number of values, so it is treated as a byte count.
//
// param: the entry to print. NULL is reported with a message instead of
//        being dereferenced.
void print_parameter(paddle_gradient* param) {
  if (param == NULL) {
    printf("param is NULL!!\n");
    return;
  }

  printf("==== parameter ====\n");
  printf("name: %s\n", param->name);
  printf("content_len: %d\n", param->content_len);
  printf("content_type: %d\n", param->element_type);

  // Guard against an entry that carries a length but no buffer.
  if (param->content != NULL) {
    // Use `real` for both the element count and the element access so the
    // two cannot drift apart if the typedef changes.
    const real* values = (const real*)param->content;
    int count = param->content_len / (int)sizeof(real);
    int i;
    for (i = 0; i < count; ++i) {
      printf("%f ", values[i]);
    }
  }
  printf("\n\n");
}

int main() {
  char addr[] = "localhost:3000";
Q
qiaolongfei 已提交
33
  paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
34 35

  char* names[] = {"param_a", "param_b"};
Q
qiaolongfei 已提交
36

37
retry:
Q
qiaolongfei 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  printf("init parameter to pserver:\n");

  real param_content1[] = {0.1, 0.2, 0.3};
  real param_content2[] = {0.4, 0.5, 0.6};
  paddle_parameter** params =
          (paddle_parameter**)malloc(sizeof(paddle_parameter*) * 2);
  params[0] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
  params[0]->name = names[0];
  params[0]->content = (unsigned char*)param_content1;
  params[0]->content_len = 3 * sizeof(real);
  params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;

  params[1] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
  params[1]->name = names[1];
  params[1]->content = (unsigned char*)param_content2;
  params[1]->content_len = 3 * sizeof(real);
  params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
55 56

  if (paddle_begin_init_params(c)) {
Q
qiaolongfei 已提交
57
    if (paddle_init_param(c, *params[0], NULL, 0) != 0) {
58 59
      goto retry;
    }
Q
qiaolongfei 已提交
60
    if (paddle_init_param(c, *params[1], NULL, 0) != 0) {
61 62 63 64 65 66 67 68 69
      goto retry;
    }
    if (paddle_finish_init_params(c) != 0) {
      goto retry;
    }
  } else {
    fail();
  }

Q
qiaolongfei 已提交
70 71 72
  printf("get inited parameters from pserver:\n");
  // get parameters again by reusing the allocated parameter buffers.
  if (paddle_get_params(c, params, 2) != 0) {
73 74
    fail();
  }
Q
qiaolongfei 已提交
75 76
  print_parameter(params[0]);
  print_parameter(params[1]);
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95

  printf("send gradient to pserver:\n");
  real gradient_content1[] = {0.01, 0.02, 0.03};
  real gradinet_content2[] = {0.04, 0.05, 0.06};

  paddle_gradient** grads =
      (paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
  grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
  grads[0]->name = names[0];
  grads[0]->content = (unsigned char*)gradient_content1;
  grads[0]->content_len = 3 * sizeof(real);
  grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;

  grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
  grads[1]->name = names[1];
  grads[1]->content = (unsigned char*)gradinet_content2;
  grads[1]->content_len = 3 * sizeof(real);
  grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;

Q
qiaolongfei 已提交
96
  printf("print gradient sent to pserver:\n");
97 98 99 100 101 102 103 104 105
  print_parameter(grads[0]);
  print_parameter(grads[1]);

  if (paddle_send_grads(c, grads, 2) != 0) {
    fail();
  }

  printf("get updated parameters from pserver:\n");
  // get parameters again by reusing the allocated parameter buffers.
Q
qiaolongfei 已提交
106
  if (paddle_get_params(c, params, 2) != 0) {
107 108
    fail();
  }
Q
qiaolongfei 已提交
109 110
  print_parameter(params[0]);
  print_parameter(params[1]);
111 112 113 114 115 116 117

  if (paddle_save_model(c, "/tmp/") != 0) {
    fail();
  }

  return 0;
}