// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package main exposes the Go pserver client to C callers through the cgo-exported paddle_* functions in this file.
package main

/*
#include <string.h>
typedef enum {
  PADDLE_ELEMENT_TYPE_INT32   = 0,
  PADDLE_ELEMENT_TYPE_UINT32  = 1,
  PADDLE_ELEMENT_TYPE_INT64   = 2,
  PADDLE_ELEMENT_TYPE_UINT64  = 3,
  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;

typedef struct {
  char*               name;
  paddle_element_type element_type;
31
  unsigned char*      content;
32 33 34
  int                 content_len;
} paddle_parameter, paddle_gradient;

35
typedef int paddle_pserver_client;
H
Helin Wang 已提交
36 37
#define PSERVER_ERROR -1
#define PSERVER_OK 0
38 39 40 41
*/
import "C"

import (
	"strings"
	"sync"
	"unsafe"

	"github.com/PaddlePaddle/Paddle/go/pserver"
	"github.com/PaddlePaddle/Paddle/go/pserver/client"
	log "github.com/inconshreveable/log15"
)

51 52 53 54 55 56
func init() {
	log.Root().SetHandler(
		log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
	)
}

57
var mu sync.Mutex
Q
Qiao Longfei 已提交
58
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
59
var curHandle C.paddle_pserver_client
60

Q
Qiao Longfei 已提交
61
func add(c *client.Client) C.paddle_pserver_client {
62 63
	mu.Lock()
	defer mu.Unlock()
64
	cli := curHandle
65
	curHandle++
66 67
	handleMap[cli] = c
	return cli
68 69
}

Q
Qiao Longfei 已提交
70
func get(client C.paddle_pserver_client) *client.Client {
71 72 73 74 75
	mu.Lock()
	defer mu.Unlock()
	return handleMap[client]
}

Q
Qiao Longfei 已提交
76
func remove(client C.paddle_pserver_client) *client.Client {
77 78 79 80 81 82 83 84
	mu.Lock()
	defer mu.Unlock()
	h := handleMap[client]
	delete(handleMap, client)
	return h
}

// cArrayToSlice wraps the n bytes starting at p in a Go slice without
// copying. It returns nil when p is nil.
//
// The result aliases the original memory: the Go garbage collector neither
// tracks nor frees it, so the owner must keep it alive while the slice is in
// use and free it properly afterwards. Reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// n must be in [0, 1<<30]; the array type bound below caps the length and
// any other value panics with a slice-bounds error.
func cArrayToSlice(p unsafe.Pointer, n int) []byte {
	if p == nil {
		return nil
	}
	// Length and capacity are both n, so an append on the result reallocates
	// into Go memory instead of writing past the end of the C buffer.
	return (*[1 << 30]byte)(p)[:n:n]
}

// selector is a fixed answer to the selection question asked by the pserver
// client (passed as the third argument of client.NewClient). Presumably it
// decides whether this trainer initializes the parameters -- confirm against
// package client.
type selector bool

// Select reports the fixed decision; it never fails.
func (s selector) Select() (bool, error) {
	decision := bool(s)
	return decision, nil
}

// Done is a no-op: a constant selector has nothing to release or report.
func (s selector) Done() error { return nil }

Q
Qiao Longfei 已提交
107
type lister []client.Server
108

Q
Qiao Longfei 已提交
109
func (l lister) List() []client.Server {
110 111 112
	return l
}

113
//export paddle_new_pserver_client
114
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
115 116
	a := C.GoString(addrs)
	as := strings.Split(a, ",")
Q
Qiao Longfei 已提交
117
	servers := make([]client.Server, len(as))
118 119 120 121
	for i := range as {
		servers[i].Index = i
		servers[i].Addr = as[i]
	}
Q
Qiao Longfei 已提交
122
	c := client.NewClient(lister(servers), len(as), selector(selected != 0))
123 124 125
	return add(c)
}

126
//export paddle_new_etcd_pserver_client
127
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
H
Helin Wang 已提交
128 129
	addr := C.GoString(etcdEndpoints)
	etcdClient := client.NewEtcd(addr)
130
	c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
Q
Qiao Longfei 已提交
131
	return add(c)
132 133
}

134
//export paddle_pserver_client_release
135
func paddle_pserver_client_release(client C.paddle_pserver_client) {
136
	remove(client)
137 138
}

139 140 141 142 143 144
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
145
//export paddle_begin_init_params
146
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
147
	c := get(client)
148 149 150 151 152 153
	selected, err := c.BeginInitParams()
	if err != nil {
		panic(err)
	}

	if selected {
154 155
		return 1
	}
156
	return 0
157 158 159
}

//export paddle_init_param
H
Helin Wang 已提交
160
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, paramConfig unsafe.Pointer, configLen C.int) C.int {
161 162 163 164 165
	et := pserver.ElementType(param.element_type)
	name := C.GoString(param.name)
	content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
	pc := pserver.ParameterWithConfig{
		Param:  pserver.Parameter{Name: name, ElementType: et, Content: content},
H
Helin Wang 已提交
166
		Config: cArrayToSlice(paramConfig, int(configLen)),
167 168 169
	}
	c := get(client)
	err := c.InitParam(pc)
170

171
	if err != nil {
172
		if err.Error() == pserver.AlreadyInitialized {
173 174 175 176
			log.Warn(
				"parameter already initialized, treat paddle_init_param as successful.",
				log.Ctx{"parameter": name},
			)
H
Helin Wang 已提交
177
			return C.PSERVER_OK
178
		}
179
		log.Error("error init param", log.Ctx{"error": err})
H
Helin Wang 已提交
180
		return C.PSERVER_ERROR
181 182
	}

H
Helin Wang 已提交
183
	return C.PSERVER_OK
184 185 186
}

//export paddle_finish_init_params
187
func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
188 189 190
	c := get(client)
	err := c.FinishInitParams()
	if err != nil {
191
		if err.Error() == pserver.AlreadyInitialized {
192
			log.Warn("parameters already initialized, treat paddle_finish_init_params as successful.")
H
Helin Wang 已提交
193
			return C.PSERVER_OK
194 195
		}

196
		log.Error("error finish init params", log.Ctx{"error": err})
H
Helin Wang 已提交
197
		return C.PSERVER_ERROR
198 199
	}

H
Helin Wang 已提交
200
	return C.PSERVER_OK
201 202 203
}

//export paddle_send_grads
204
func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient, total C.int) C.int {
205 206
	var gs []pserver.Gradient
	for i := 0; i < int(total); i++ {
207
		grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
208 209 210 211 212 213 214 215 216
		et := pserver.ElementType(grad.element_type)
		name := C.GoString(grad.name)
		content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
		gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
	}

	c := get(client)
	err := c.SendGrads(gs)
	if err != nil {
217
		log.Error("error send grads", log.Ctx{"error": err})
H
Helin Wang 已提交
218
		return C.PSERVER_ERROR
219 220
	}

H
Helin Wang 已提交
221
	return C.PSERVER_OK
222 223 224
}

//export paddle_get_params
225
func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
226 227
	var ns []string
	for i := 0; i < int(total); i++ {
228 229
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
		ns = append(ns, C.GoString(param.name))
230 231 232 233
	}
	c := get(client)
	ps, err := c.GetParams(ns)
	if err != nil {
234
		log.Error("error get params", log.Ctx{"error": err})
H
Helin Wang 已提交
235
		return C.PSERVER_ERROR
236 237
	}

238
	if len(ps) != len(ns) {
H
Helin Wang 已提交
239 240 241 242
		pn := make([]string, len(ps))
		for i, p := range ps {
			pn[i] = p.Name
		}
243 244 245 246 247 248 249
		log.Error(
			"pserver returned wrong number of parameters.",
			log.Ctx{
				"Requested": strings.Join(pn, ", "),
				"Returned":  strings.Join(ns, ", "),
			},
		)
H
Helin Wang 已提交
250
		return C.PSERVER_ERROR
251 252 253 254
	}

	for i := range ps {
		if ns[i] != ps[i].Name {
H
Helin Wang 已提交
255 256 257 258
			pn := make([]string, len(ps))
			for i, p := range ps {
				pn[i] = p.Name
			}
259 260 261 262 263 264 265
			log.Error(
				"pserver returned wrong parameters, or not in requested order.",
				log.Ctx{
					"Requested": strings.Join(pn, ", "),
					"Returned":  strings.Join(ns, ", "),
				},
			)
H
Helin Wang 已提交
266
			return C.PSERVER_ERROR
267
		}
268
	}
269

270
	for i := 0; i < int(total); i++ {
271
		p := ps[i]
H
Helin Wang 已提交
272 273
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))

274
		if unsafe.Pointer(param) == nil {
275
			log.Error("must pre-allocate parameter.")
H
Helin Wang 已提交
276
			return C.PSERVER_ERROR
277 278
		}

279
		if unsafe.Pointer(param.content) != nil {
280
			if int(param.content_len) != len(p.Content) {
281 282 283 284 285 286 287
				log.Error(
					"the pre-allocated content len does not match parameter content len.",
					log.Ctx{
						"Pre-allocated len": param.content_len,
						"Returned len":      len(p.Content),
					},
				)
288
				return C.PSERVER_ERROR
289 290 291
			}
		}

H
Helin Wang 已提交
292
		C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
293 294 295 296
		param.content_len = C.int(len(p.Content))
		param.element_type = C.paddle_element_type(p.ElementType)
	}

H
Helin Wang 已提交
297
	return C.PSERVER_OK
298 299 300
}

// main exists only because this is package main; callers use the //export'ed
// functions instead (presumably built with -buildmode=c-archive or c-shared),
// so it is ignored.
func main() {}