cclient.go 6.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
package main

/*
#include <stdlib.h>
#include <string.h>
typedef enum {
  PADDLE_ELEMENT_TYPE_INT32   = 0,
  PADDLE_ELEMENT_TYPE_UINT32  = 1,
  PADDLE_ELEMENT_TYPE_INT64   = 2,
  PADDLE_ELEMENT_TYPE_UINT64  = 3,
  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;

typedef struct {
  char*               name;
  paddle_element_type element_type;
18
  unsigned char*      content;
19 20 21
  int                 content_len;
} paddle_parameter, paddle_gradient;

22
typedef int paddle_pserver_client;
23 24 25 26 27
*/
import "C"

import (
	"log"
	"strings"
	"sync"
	"unsafe"

	"github.com/PaddlePaddle/Paddle/go/pserver"
)

H
Helin Wang 已提交
35
var nullPtr = unsafe.Pointer(uintptr(0))
36
var mu sync.Mutex
37 38
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
var curHandle C.paddle_pserver_client
39

40
func add(c *pserver.Client) C.paddle_pserver_client {
41 42 43 44 45 46 47 48
	mu.Lock()
	defer mu.Unlock()
	client := curHandle
	curHandle++
	handleMap[client] = c
	return client
}

49
func get(client C.paddle_pserver_client) *pserver.Client {
50 51 52 53 54
	mu.Lock()
	defer mu.Unlock()
	return handleMap[client]
}

55
func remove(client C.paddle_pserver_client) *pserver.Client {
56 57 58 59 60 61 62 63
	mu.Lock()
	defer mu.Unlock()
	h := handleMap[client]
	delete(handleMap, client)
	return h
}

// cArrayToSlice returns a Go byte slice of length and capacity n backed
// directly by the C array at p; no data is copied. It returns nil when p is
// nil or n is not positive (a negative n would otherwise panic while
// slicing, taking down the host process across the cgo boundary).
//
// The Go garbage collector does not manage the backing memory: it must stay
// valid while the slice is in use and must be freed by its owner. Reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
func cArrayToSlice(p unsafe.Pointer, n int) []byte {
	if p == nil || n <= 0 {
		return nil
	}
	// View p as a pointer to a very large array and re-slice it to n elements.
	// The array type bounds n at 1<<30; larger values panic.
	return (*[1 << 30]byte)(p)[:n:n]
}

// selector is passed to pserver.NewClient and always gives the same answer,
// chosen by the C caller in paddle_new_pserver_client.
type selector bool

// Select reports the constant stored in s.
func (s selector) Select() bool {
	if s {
		return true
	}
	return false
}

// lister hands pserver.NewClient a fixed, caller-supplied list of servers.
type lister []pserver.Server

// List returns the servers held by l (the same backing array, not a copy).
func (l lister) List() []pserver.Server {
	servers := []pserver.Server(l)
	return servers
}

88
//export paddle_new_pserver_client
89
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
90 91 92 93 94 95 96
	a := C.GoString(addrs)
	as := strings.Split(a, ",")
	servers := make([]pserver.Server, len(as))
	for i := range as {
		servers[i].Index = i
		servers[i].Addr = as[i]
	}
H
Helin Wang 已提交
97
	c := pserver.NewClient(lister(servers), len(as), selector(selected != 0))
98 99 100
	return add(c)
}

101
//export paddle_new_etcd_pserver_client
102
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
103 104 105 106
	// TODO(helin): fault tolerant pserver client using etcd.
	panic("not implemented.")
}

107
//export paddle_pserver_client_release
108
func paddle_pserver_client_release(client C.paddle_pserver_client) {
109
	remove(client)
110 111 112
}

//export paddle_begin_init_params
113
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
114
	c := get(client)
115
	if selected := c.BeginInitParams(); selected {
116 117 118 119 120 121
		return 1
	}
	return 0
}

//export paddle_init_param
122
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
123 124 125 126 127 128 129 130 131
	et := pserver.ElementType(param.element_type)
	name := C.GoString(param.name)
	content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
	pc := pserver.ParameterWithConfig{
		Param:  pserver.Parameter{Name: name, ElementType: et, Content: content},
		Config: cArrayToSlice(param_config, int(config_len)),
	}
	c := get(client)
	err := c.InitParam(pc)
132

133
	if err != nil {
134
		if err.Error() == pserver.AlreadyInitialized {
H
Helin Wang 已提交
135
			log.Printf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name)
136 137
			return 0
		}
138 139 140 141 142 143 144 145
		log.Println(err)
		return -1
	}

	return 0
}

//export paddle_finish_init_params
146
func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
147 148 149
	c := get(client)
	err := c.FinishInitParams()
	if err != nil {
150 151 152 153 154
		if err.Error() == pserver.AlreadyInitialized {
			log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.")
			return 0
		}

155 156 157 158 159 160 161 162
		log.Println(err)
		return -1
	}

	return 0
}

//export paddle_send_grads
163
func paddle_send_grads(client C.paddle_pserver_client, grads *C.paddle_gradient, total C.int) C.int {
164 165
	var gs []pserver.Gradient
	for i := 0; i < int(total); i++ {
H
Helin Wang 已提交
166
		grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
		et := pserver.ElementType(grad.element_type)
		name := C.GoString(grad.name)
		content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
		gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
	}

	c := get(client)
	err := c.SendGrads(gs)
	if err != nil {
		log.Println(err)
		return -1
	}

	return 0
}

//export paddle_get_params
184
func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
185 186
	var ns []string
	for i := 0; i < int(total); i++ {
187 188
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
		ns = append(ns, C.GoString(param.name))
189 190 191 192 193 194 195 196
	}
	c := get(client)
	ps, err := c.GetParams(ns)
	if err != nil {
		log.Println(err)
		return -1
	}

H
Helin Wang 已提交
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
	names := func() (string, string) {
		retNames := ""
		for _, p := range ps {
			if retNames == "" {
				retNames = p.Name
			} else {
				retNames = ", " + p.Name
			}
		}

		requestedNames := ""
		for _, n := range ns {
			if requestedNames == "" {
				requestedNames = n
			} else {
				requestedNames = ", " + n
			}
		}

		return requestedNames, retNames
	}

219
	if len(ps) != len(ns) {
H
Helin Wang 已提交
220 221
		requestedNames, retNames := names()
		log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", retNames, requestedNames)
222 223 224 225 226
		return -1
	}

	for i := range ps {
		if ns[i] != ps[i].Name {
H
Helin Wang 已提交
227 228
			requestedNames, retNames := names()
			log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", retNames, requestedNames)
229
			return -1
230
		}
231
	}
232

233
	for i := 0; i < int(total); i++ {
234
		p := ps[i]
H
Helin Wang 已提交
235 236 237
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))

		if unsafe.Pointer(param) == nullPtr {
H
Helin Wang 已提交
238
			log.Println("must pre-allocate parameter.")
239
			return -1
H
Helin Wang 已提交
240
		} else {
H
Helin Wang 已提交
241
			if unsafe.Pointer(param.content) != nullPtr {
242
				if int(param.content_len) != len(p.Content) {
H
Helin Wang 已提交
243
					log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
244
					return -1
H
Helin Wang 已提交
245
				}
246 247 248
			}
		}

H
Helin Wang 已提交
249
		C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
250 251 252 253 254 255 256 257
		param.content_len = C.int(len(p.Content))
		param.element_type = C.paddle_element_type(p.ElementType)
	}

	return 0
}

//export paddle_save_model
258
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
259 260
	p := C.GoString(path)
	c := get(client)
261
	err := c.Save(p)
262 263 264 265 266 267 268 269 270
	if err != nil {
		log.Println(err)
		return -1
	}

	return 0
}

// main is required for package main but is never meant to run; this package
// is presumably built as a C library (e.g. -buildmode=c-archive/c-shared)
// exposing the //export functions.
func main() {
}