diff --git a/go/master/c/client.go b/go/master/c/client.go index 220184c3af73118a8dc6488370ebf9ea4d8d6137..8a437eb2238c64a2def03e9063e87b73cd2377c6 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -1,6 +1,12 @@ package main /* +#include +#include +#include + +#define PADDLE_MASTER_OK 0 +#define PADDLE_MASTER_ERROR -1 typedef int paddle_master_client; */ @@ -14,6 +20,7 @@ import ( "github.com/PaddlePaddle/Paddle/go/master" ) +var nullPtr = unsafe.Pointer(uintptr(0)) var mu sync.Mutex var handleMap = make(map[C.paddle_master_client]*master.Client) var curHandle C.paddle_master_client @@ -47,17 +54,16 @@ func (a addresser) Address() string { return string(a) } -//paddle_new_master_client -func paddle_new_master_client(addr *C.char, buf_size C.int) C.paddle_master_client { +//export paddle_new_master_client +func paddle_new_master_client(addr *C.char) C.paddle_master_client { a := C.GoString(addr) - c := master.NewClient(addresser(a), int(buf_size)) + c := master.NewClient(addresser(a)) return add(c) } -//export paddle_new_etcd_master_client -func paddle_new_etcd_master_client(etcd_addr *C.char) C.paddle_master_client { - // TODO(helin): fault tolerant master client using etcd. - panic("not implemented.") +//export paddle_release_master_client +func paddle_release_master_client(client C.paddle_master_client) { + remove(client) } //export paddle_set_dataset @@ -65,17 +71,40 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int c := get(client) var paths []string for i := 0; i < int(size); i++ { - ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(size))) + ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path))) str := C.GoString(*ptr) paths = append(paths, str) } err := c.SetDataset(paths) if err != nil { log.Println(err) - return -1 + return C.PADDLE_MASTER_ERROR + } + + return C.PADDLE_MASTER_OK +} + +//export paddle_next_record +func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { + c := get(client) + r := c.NextRecord() + if len(r) == 0 { + *record = (*C.uchar)(nullPtr) + return 0 } - return 0 + size := C.size_t(len(r)) + *record = (*C.uchar)(C.malloc(size)) + C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size) + return C.int(size) +} + +//export mem_free +func mem_free(p unsafe.Pointer) { + // "free" may be a better name for this function, but doing so + // will cause calling any function of this library from Python + // ctypes hanging. + C.free(p) } func main() {} diff --git a/go/master/client.go b/go/master/client.go index 1c8a8d73d0ca755f6ce13d0e2cc494a3443e3c50..73c945ddc5a616a46a83ad9c704f492341344663 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -21,13 +21,10 @@ type Client struct { } // NewClient creates a new Client. -// -// bufSize is the record buffer size. NextRecord will read from the -// buffer. -func NewClient(addr Addresser, bufSize int) *Client { +func NewClient(addr Addresser) *Client { c := &Client{} c.conn = connection.New() - c.ch = make(chan []byte, bufSize) + c.ch = make(chan []byte) go c.monitorMaster(addr) go c.getRecords() return c @@ -53,11 +50,19 @@ func (c *Client) getRecords() { c.ch <- s.Record() } + if s.Err() != nil { + log.Println(err, chunk.Path) + } + err = f.Close() if err != nil { log.Println(err) } } + + // We treat a task as finished whenever the last data + // instance of the task is read. This is not exactly + // correct, but a reasonable approximation. c.taskFinished(t.ID) } } diff --git a/go/master/client_test.go b/go/master/client_test.go index 2b3f873ecf3a650cd91d1d9c20b414b05bbb0cd6..02751aeb301dfe000e37f757cf6503f42322db0a 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -60,7 +60,7 @@ func TestNextRecord(t *testing.T) { w.Close() f.Close() - c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10) + c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p))) c.SetDataset([]string{path}) for pass := 0; pass < 50; pass++ { diff --git a/go/master/python/.gitignore b/go/master/python/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..704d307510b7c4cf1acaba7125e54458cf926031 --- /dev/null +++ b/go/master/python/.gitignore @@ -0,0 +1 @@ +*.whl diff --git a/go/master/python/build.sh b/go/master/python/build.sh new file mode 100755 index 0000000000000000000000000000000000000000..e3dbd7b0bc322a6a12f21b8708623919c9c4260f --- /dev/null +++ b/go/master/python/build.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +go build -buildmode=c-shared ../c && rm c.h && mv c paddle_master/libmaster.so +pip wheel . diff --git a/go/master/python/paddle_master/.gitignore b/go/master/python/paddle_master/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4cd9921259090479f917bcc7537db5ca71d0c8ee --- /dev/null +++ b/go/master/python/paddle_master/.gitignore @@ -0,0 +1,2 @@ +*.so +*.pyc diff --git a/go/master/python/paddle_master/__init__.py b/go/master/python/paddle_master/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c8975b5d4a33cbecb4fa5a144bc610c36591d629 --- /dev/null +++ b/go/master/python/paddle_master/__init__.py @@ -0,0 +1,3 @@ +from client import * + +__all__ = ['client'] diff --git a/go/master/python/paddle_master/client.py b/go/master/python/paddle_master/client.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2942f6969dc95c2ea423c3018b30962ffba69c --- /dev/null +++ b/go/master/python/paddle_master/client.py @@ -0,0 +1,39 @@ +import ctypes +import os + +path = os.path.join(os.path.dirname(__file__), "libmaster.so") +lib = ctypes.cdll.LoadLibrary(path) + + +class client(object): + """ + client is a client to the master server. + """ + + def __init__(self, addr, buf_size): + self.c = lib.paddle_new_master_client(addr, buf_size) + + def close(self): + lib.paddle_release_master_client(self.c) + self.c = None + + def set_dataset(self, paths): + holder_type = ctypes.c_char_p * len(paths) + holder = holder_type() + print paths + for idx, path in enumerate(paths): + c_ptr = ctypes.c_char_p(path) + holder[idx] = c_ptr + lib.paddle_set_dataset(self.c, holder, len(paths)) + + def next_record(self): + p = ctypes.c_char_p() + ret = ctypes.pointer(p) + size = lib.paddle_next_record(self.c, ret) + if size == 0: + # empty record + return "" + record = ret.contents.value[:size] + # memory created from C should be freed. + lib.mem_free(ret.contents) + return record diff --git a/go/master/python/setup.py b/go/master/python/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b6e9ecab69cc6372122d62a07ab435a11e2e7c --- /dev/null +++ b/go/master/python/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup, Distribution + + +class BinaryDistribution(Distribution): + def has_ext_modules(foo): + return True + + +setup( + name='paddle_master', + version='0.1', + description='The client of the master server of PaddlePaddle.', + url='https://github.com/PaddlePaddle/Paddle/go/master/python', + author='PaddlePaddle Authors', + author_email='paddle-dev@baidu.com', + license='Apache 2.0', + packages=['paddle_master'], + package_data={'master': ['libmaster.so'], }, + distclass=BinaryDistribution) diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 4476e762dad04009833421056aa5a49efd44ddaa..3e074a9f2aa919369fc8ae1a2f90bb619b540898 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -1,7 +1,6 @@ package main /* -#include #include typedef enum { PADDLE_ELEMENT_TYPE_INT32 = 0, @@ -223,12 +222,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, if unsafe.Pointer(param) == nullPtr { log.Println("must pre-allocate parameter.") return C.PSERVER_ERROR - } else { - if unsafe.Pointer(param.content) != nullPtr { - if int(param.content_len) != len(p.Content) { - log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) - return C.PSERVER_ERROR - } + } + + if unsafe.Pointer(param.content) != nullPtr { + if int(param.content_len) != len(p.Content) { + log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) + return C.PSERVER_ERROR } }