cclient.go 6.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
package main

/*
#include <string.h>
typedef enum {
  PADDLE_ELEMENT_TYPE_INT32   = 0,
  PADDLE_ELEMENT_TYPE_UINT32  = 1,
  PADDLE_ELEMENT_TYPE_INT64   = 2,
  PADDLE_ELEMENT_TYPE_UINT64  = 3,
  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;

typedef struct {
  char*               name;
  paddle_element_type element_type;
17
  unsigned char*      content;
18 19 20
  int                 content_len;
} paddle_parameter, paddle_gradient;

21
typedef int paddle_pserver_client;
H
Helin Wang 已提交
22 23
#define PSERVER_ERROR -1
#define PSERVER_OK 0
24 25 26 27
*/
import "C"

import (
28
	"strings"
29 30 31
	"sync"
	"unsafe"

32
	"github.com/PaddlePaddle/Paddle/go/pserver"
H
Helin Wang 已提交
33
	log "github.com/sirupsen/logrus"
34 35
)

H
Helin Wang 已提交
36
var nullPtr = unsafe.Pointer(uintptr(0))
37
var mu sync.Mutex
38 39
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
var curHandle C.paddle_pserver_client
40

41
func add(c *pserver.Client) C.paddle_pserver_client {
42 43 44 45 46 47 48 49
	mu.Lock()
	defer mu.Unlock()
	client := curHandle
	curHandle++
	handleMap[client] = c
	return client
}

50
func get(client C.paddle_pserver_client) *pserver.Client {
51 52 53 54 55
	mu.Lock()
	defer mu.Unlock()
	return handleMap[client]
}

56
func remove(client C.paddle_pserver_client) *pserver.Client {
57 58 59 60 61 62 63 64
	mu.Lock()
	defer mu.Unlock()
	h := handleMap[client]
	delete(handleMap, client)
	return h
}

func cArrayToSlice(p unsafe.Pointer, len int) []byte {
H
Helin Wang 已提交
65 66 67 68
	if p == nullPtr {
		return nil
	}

69 70 71 72
	// create a Go clice backed by a C array, reference:
	// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
	//
	// Go garbage collector will not interact with this data, need
H
Helin Wang 已提交
73
	// to be freed properly.
74 75 76
	return (*[1 << 30]byte)(p)[:len:len]
}

77 78 79 80 81 82 83 84 85 86 87 88
// selector is a boolean wrapped in a type so that it can carry a
// Select method; paddle_new_pserver_client passes one to
// pserver.NewClient.
type selector bool

// Select returns the boolean value the selector was created from.
func (sel selector) Select() bool {
	return bool(sel)
}

// lister is a fixed list of servers wrapped in a type so that it can
// carry a List method; paddle_new_pserver_client passes one to
// pserver.NewClient.
type lister []pserver.Server

// List returns the servers the lister was created from. The result
// shares its backing array with the receiver.
func (srvs lister) List() []pserver.Server {
	return []pserver.Server(srvs)
}

89
//export paddle_new_pserver_client
90
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
91 92 93 94 95 96 97
	a := C.GoString(addrs)
	as := strings.Split(a, ",")
	servers := make([]pserver.Server, len(as))
	for i := range as {
		servers[i].Index = i
		servers[i].Addr = as[i]
	}
H
Helin Wang 已提交
98
	c := pserver.NewClient(lister(servers), len(as), selector(selected != 0))
99 100 101
	return add(c)
}

102
//export paddle_new_etcd_pserver_client
103
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
104 105 106 107
	// TODO(helin): fault tolerant pserver client using etcd.
	panic("not implemented.")
}

108
//export paddle_pserver_client_release
109
func paddle_pserver_client_release(client C.paddle_pserver_client) {
110
	remove(client)
111 112 113
}

//export paddle_begin_init_params
114
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
115
	c := get(client)
116
	if selected := c.BeginInitParams(); selected {
117 118
		return 1
	}
H
Helin Wang 已提交
119
	return C.PSERVER_OK
120 121 122
}

//export paddle_init_param
123
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
D
dzhwinter 已提交
124
	et
125

126
	if err != nil {
127
		if err.Error() == pserver.AlreadyInitialized {
H
Helin Wang 已提交
128
			log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name)
H
Helin Wang 已提交
129
			return C.PSERVER_OK
130
		}
H
Helin Wang 已提交
131
		log.Errorln(err)
H
Helin Wang 已提交
132
		return C.PSERVER_ERROR
133 134
	}

H
Helin Wang 已提交
135
	return C.PSERVER_OK
136 137 138
}

//export paddle_finish_init_params
139
func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
140 141 142
	c := get(client)
	err := c.FinishInitParams()
	if err != nil {
143
		if err.Error() == pserver.AlreadyInitialized {
H
Helin Wang 已提交
144
			log.Warningln("parameters already initialized, treat paddle_finish_init_params as sucessful.")
H
Helin Wang 已提交
145
			return C.PSERVER_OK
146 147
		}

H
Helin Wang 已提交
148
		log.Errorln(err)
H
Helin Wang 已提交
149
		return C.PSERVER_ERROR
150 151
	}

H
Helin Wang 已提交
152
	return C.PSERVER_OK
153 154 155
}

//export paddle_send_grads
156
func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient, total C.int) C.int {
157 158
	var gs []pserver.Gradient
	for i := 0; i < int(total); i++ {
159
		grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
160 161
		et := pserver.ElementType(grad.element_type)
		name := C.GoString(grad.name)
D
dzhwinter 已提交
162
		gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: grad.content, Length: grad.content_len})
163 164 165 166 167
	}

	c := get(client)
	err := c.SendGrads(gs)
	if err != nil {
H
Helin Wang 已提交
168
		log.Errorln(err)
H
Helin Wang 已提交
169
		return C.PSERVER_ERROR
170 171
	}

H
Helin Wang 已提交
172
	return C.PSERVER_OK
173 174 175
}

//export paddle_get_params
176
func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
177 178
	var ns []string
	for i := 0; i < int(total); i++ {
179 180
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
		ns = append(ns, C.GoString(param.name))
181 182 183 184
	}
	c := get(client)
	ps, err := c.GetParams(ns)
	if err != nil {
H
Helin Wang 已提交
185
		log.Errorln(err)
H
Helin Wang 已提交
186
		return C.PSERVER_ERROR
187 188
	}

189
	if len(ps) != len(ns) {
H
Helin Wang 已提交
190 191 192 193
		pn := make([]string, len(ps))
		for i, p := range ps {
			pn[i] = p.Name
		}
H
Helin Wang 已提交
194
		log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
H
Helin Wang 已提交
195
		return C.PSERVER_ERROR
196 197 198 199
	}

	for i := range ps {
		if ns[i] != ps[i].Name {
H
Helin Wang 已提交
200 201 202 203
			pn := make([]string, len(ps))
			for i, p := range ps {
				pn[i] = p.Name
			}
H
Helin Wang 已提交
204
			log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
H
Helin Wang 已提交
205
			return C.PSERVER_ERROR
206
		}
207
	}
208

209
	for i := 0; i < int(total); i++ {
210
		p := ps[i]
H
Helin Wang 已提交
211 212 213
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))

		if unsafe.Pointer(param) == nullPtr {
H
Helin Wang 已提交
214
			log.Errorln("must pre-allocate parameter.")
H
Helin Wang 已提交
215
			return C.PSERVER_ERROR
216 217 218
		}

		if unsafe.Pointer(param.content) != nullPtr {
D
dzhwinter 已提交
219
			if int(param.content_len) != p.Length {
H
Helin Wang 已提交
220
				log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
221
				return C.PSERVER_ERROR
222 223 224
			}
		}

D
dzhwinter 已提交
225 226
		C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(p.Content), C.size_t(p.Length))
		param.content_len = C.int(p.Length)
227 228 229
		param.element_type = C.paddle_element_type(p.ElementType)
	}

H
Helin Wang 已提交
230
	return C.PSERVER_OK
231 232 233
}

//export paddle_save_model
234
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
235 236
	p := C.GoString(path)
	c := get(client)
237
	err := c.Save(p)
238
	if err != nil {
H
Helin Wang 已提交
239
		log.Errorln(err)
H
Helin Wang 已提交
240
		return C.PSERVER_ERROR
241 242
	}

H
Helin Wang 已提交
243
	return C.PSERVER_OK
244 245 246
}

// main is intentionally empty. It exists only because this file is
// declared as package main; the functionality of the file is reached
// through the functions marked with //export.
// NOTE(review): presumably built as a C library (c-shared/c-archive
// build mode) — confirm against the build scripts.
func main() {} // Required but ignored