提交 7b9080ef 编写于 作者: H Helin Wang

Implement master client, cgo and Python part

上级 fa5c3f1f
package main
/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
typedef int paddle_master_client;
*/
......@@ -14,6 +20,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/master"
)
// nullPtr is a nil unsafe.Pointer; it is cast to C pointer types when a
// NULL pointer has to be handed back to the C caller.
var nullPtr = unsafe.Pointer(uintptr(0))
// mu is a package-level mutex.
// NOTE(review): presumably guards handleMap and curHandle; the helpers
// that use it are not visible here, confirm against add/get/remove.
var mu sync.Mutex
// handleMap associates each integer handle given out to C with the Go
// master client it stands for.
var handleMap = make(map[C.paddle_master_client]*master.Client)
// curHandle is a handle value of the same type as handleMap's keys.
// NOTE(review): presumably the most recently issued handle; confirm in add.
var curHandle C.paddle_master_client
......@@ -47,17 +54,16 @@ func (a addresser) Address() string {
return string(a)
}
//paddle_new_master_client
func paddle_new_master_client(addr *C.char, buf_size C.int) C.paddle_master_client {
//export paddle_new_master_client
func paddle_new_master_client(addr *C.char) C.paddle_master_client {
a := C.GoString(addr)
c := master.NewClient(addresser(a), int(buf_size))
c := master.NewClient(addresser(a))
return add(c)
}
//export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcd_addr *C.char) C.paddle_master_client {
// TODO(helin): fault tolerant master client using etcd.
panic("not implemented.")
//export paddle_release_master_client
func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}
//export paddle_set_dataset
......@@ -65,17 +71,40 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
c := get(client)
var paths []string
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(size)))
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
}
err := c.SetDataset(paths)
if err != nil {
log.Println(err)
return -1
return C.PADDLE_MASTER_ERROR
}
return C.PADDLE_MASTER_OK
}
// paddle_next_record fetches the next record from the master client
// identified by client and hands it to the C caller through record.
//
// On return, *record is NULL and the result is 0 when the record is
// empty. Otherwise *record points at a buffer allocated with C.malloc
// that holds a copy of the record, and the result is the record length
// in bytes. The caller owns that buffer and has to release it with
// mem_free.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
	rec := get(client).NextRecord()
	n := len(rec)
	if n == 0 {
		*record = (*C.uchar)(nullPtr)
		return 0
	}

	// The record lives in Go memory, so copy it into C-owned memory
	// before exposing it to the caller.
	buf := C.malloc(C.size_t(n))
	C.memcpy(buf, unsafe.Pointer(&rec[0]), C.size_t(n))
	*record = (*C.uchar)(buf)
	return C.int(n)
}
// mem_free releases memory that this library allocated with C.malloc,
// such as the buffer returned through paddle_next_record.
//
//export mem_free
func mem_free(p unsafe.Pointer) {
	// The name is deliberately not "free": exporting a function with
	// that name makes every call into this library from Python ctypes
	// hang.
	C.free(p)
}
// main is intentionally empty. The package is compiled as a C shared
// library (the build script uses -buildmode=c-shared), and that build mode
// still requires a main function in package main; the real entry points
// are the //export'ed functions.
func main() {}
......@@ -21,13 +21,10 @@ type Client struct {
}
// NewClient creates a new Client.
//
// bufSize is the record buffer size. NextRecord will read from the
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
func NewClient(addr Addresser) *Client {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan []byte, bufSize)
c.ch = make(chan []byte)
go c.monitorMaster(addr)
go c.getRecords()
return c
......@@ -53,11 +50,19 @@ func (c *Client) getRecords() {
c.ch <- s.Record()
}
if s.Err() != nil {
log.Println(err, chunk.Path)
}
err = f.Close()
if err != nil {
log.Println(err)
}
}
// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
}
}
......
......@@ -60,7 +60,7 @@ func TestNextRecord(t *testing.T) {
w.Close()
f.Close()
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10)
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)))
c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ {
......
#!/bin/bash
# Builds the Go master client as a C shared library and packages it,
# together with the Python wrapper, into a wheel.

# Abort on the first failing step so that a broken build never gets
# packaged with a stale or missing libmaster.so.
set -e

# The c-shared build of ../c emits the library as "c" plus a generated
# header "c.h"; only the library is shipped inside the Python package.
go build -buildmode=c-shared ../c
rm c.h
mv c paddle_master/libmaster.so

pip wheel .
from client import *
__all__ = ['client']
import ctypes
import os
# libmaster.so is expected to sit next to this module inside the package
# directory; load it once at import time and share the handle via `lib`.
path = os.path.join(os.path.dirname(__file__), "libmaster.so")
lib = ctypes.cdll.LoadLibrary(path)
class client(object):
    """
    client is a client to the master server.

    It is a thin ctypes wrapper around the functions exported by
    libmaster.so; `self.c` holds the integer handle returned by the
    library and is passed back on every call.
    """

    def __init__(self, addr, buf_size):
        """Creates a master client.

        :param addr: address of the master server, passed to the
            library as a C string.
        :param buf_size: record buffer size, forwarded to the library.
            NOTE(review): the exported Go constructor visible in this
            change takes only `addr`; confirm whether `buf_size` is
            still consumed on the library side.
        """
        self.c = lib.paddle_new_master_client(addr, buf_size)

    def close(self):
        """Releases the underlying client handle and clears it."""
        lib.paddle_release_master_client(self.c)
        self.c = None

    def set_dataset(self, paths):
        """Tells the master which dataset files to serve.

        :param paths: sequence of file paths, each passed to the
            library as a C string.
        """
        # Build a C array of char* so the library sees a `char **`.
        holder_type = ctypes.c_char_p * len(paths)
        holder = holder_type()
        for idx, path in enumerate(paths):
            holder[idx] = ctypes.c_char_p(path)
        lib.paddle_set_dataset(self.c, holder, len(paths))

    def next_record(self):
        """Returns the next record as a byte string.

        :return: the record content, or "" when the library reports an
            empty record (size 0).
        """
        p = ctypes.c_char_p()
        ret = ctypes.pointer(p)
        size = lib.paddle_next_record(self.c, ret)
        if size == 0:
            # empty record
            return ""
        # Copy exactly `size` bytes. Reading through `c_char_p.value`
        # would stop at the first NUL byte and silently truncate binary
        # records.
        record = ctypes.string_at(p, size)
        # memory created from C should be freed.
        lib.mem_free(p)
        return record
from setuptools import setup, Distribution
class BinaryDistribution(Distribution):
    """Distribution that always reports itself as containing extension
    modules, since the package ships a prebuilt shared library."""

    def has_ext_modules(self):
        return True
# Package the Python wrapper together with the prebuilt libmaster.so.
setup(
    name='paddle_master',
    version='0.1',
    description='The client of the master server of PaddlePaddle.',
    url='https://github.com/PaddlePaddle/Paddle/go/master/python',
    author='PaddlePaddle Authors',
    author_email='paddle-dev@baidu.com',
    license='Apache 2.0',
    packages=['paddle_master'],
    # package_data is keyed by package name. The only package is
    # 'paddle_master', so the shared library must be listed under that key
    # or it is left out of the built distribution.
    package_data={'paddle_master': ['libmaster.so'], },
    distclass=BinaryDistribution)
package main
/*
#include <stdlib.h>
#include <string.h>
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
......@@ -223,14 +222,14 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
if unsafe.Pointer(param) == nullPtr {
log.Println("must pre-allocate parameter.")
return C.PSERVER_ERROR
} else {
}
if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) != len(p.Content) {
log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
return C.PSERVER_ERROR
}
}
}
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
param.content_len = C.int(len(p.Content))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册