提交 1a12720b 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #2468 from helinwang/master_dispatch

Implement master client for reading training tasks
...@@ -126,6 +126,7 @@ endif(WITH_GPU) ...@@ -126,6 +126,7 @@ endif(WITH_GPU)
add_subdirectory(proto) add_subdirectory(proto)
add_subdirectory(paddle) add_subdirectory(paddle)
add_subdirectory(go/master/c)
add_subdirectory(python) add_subdirectory(python)
add_subdirectory(go/pserver/cclient) add_subdirectory(go/pserver/cclient)
......
...@@ -26,27 +26,23 @@ function(GO_LIBRARY NAME BUILD_TYPE) ...@@ -26,27 +26,23 @@ function(GO_LIBRARY NAME BUILD_TYPE)
# automatically get all dependencies specified in the source code # automatically get all dependencies specified in the source code
# for given target. # for given target.
add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...) add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
# make a symlink that references Paddle inside $GOPATH, so go get # make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle # will use the local changes in Paddle rather than checkout Paddle
# in github. # in github.
add_custom_target(copyPaddle add_custom_target(${NAME}_copyPaddle
COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle) COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
add_dependencies(goGet copyPaddle) add_dependencies(${NAME}_goGet ${NAME}_copyPaddle)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-gcflags=-shared -asmflags=-shared -installsuffix=_shared -a
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE} ${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} goGet) add_dependencies(${NAME} ${NAME}_goGet)
if(NOT BUILD_TYPE STREQUAL "STATIC")
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin)
endif()
endfunction(GO_LIBRARY) endfunction(GO_LIBRARY)
...@@ -2,9 +2,10 @@ package connection ...@@ -2,9 +2,10 @@ package connection
import ( import (
"errors" "errors"
"log"
"net/rpc" "net/rpc"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
// TODO(helin): add TCP re-connect logic // TODO(helin): add TCP re-connect logic
...@@ -65,7 +66,7 @@ func (c *Conn) Connect(addr string) error { ...@@ -65,7 +66,7 @@ func (c *Conn) Connect(addr string) error {
} else { } else {
err := client.Close() err := client.Close()
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
} }
return errors.New("client already set from a concurrent goroutine") return errors.New("client already set from a concurrent goroutine")
......
cmake_minimum_required(VERSION 3.0)
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
project(cxx_go C Go)
include(golang)
include(flags)
set(MASTER_LIB_NAME "paddle_master")
go_library(${MASTER_LIB_NAME} SHARED)
if(PROJ_ROOT)
add_custom_command(OUTPUT ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so
COMMAND rm ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.h
COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.so ${PROJ_ROOT}/python/paddle/v2/master/
DEPENDS ${MASTER_LIB_NAME})
add_custom_target(paddle_master_shared ALL DEPENDS ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so)
endif(PROJ_ROOT)
package main
/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
typedef int paddle_master_client;
*/
import "C"
import (
"sync"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
log "github.com/sirupsen/logrus"
)
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client
func add(c *master.Client) C.paddle_master_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
curHandle++
handleMap[client] = c
return client
}
func get(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}
func remove(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}
type addresser string
func (a addresser) Address() string {
return string(a)
}
//export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize)
return add(c)
}
//export paddle_release_master_client
func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
var paths []string
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
}
err := c.SetDataset(paths)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
return C.PADDLE_MASTER_OK
}
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r := c.NextRecord()
if len(r) == 0 {
*record = (*C.uchar)(nullPtr)
return 0
}
size := C.size_t(len(r))
*record = (*C.uchar)(C.malloc(size))
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
return C.int(size)
}
//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so
// will cause calling any function of this library from Python
// ctypes hanging.
C.free(p)
}
func main() {}
package master package master
import ( import (
"log" "os"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus"
) )
// Addresser provide the address of the master server. // Addresser provide the address of the master server.
...@@ -15,16 +17,61 @@ type Addresser interface { ...@@ -15,16 +17,61 @@ type Addresser interface {
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan []byte
} }
// NewClient creates a new Client. // NewClient creates a new Client.
func NewClient(addr Addresser) *Client { //
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan []byte, bufSize)
go c.monitorMaster(addr) go c.monitorMaster(addr)
go c.getRecords()
return c return c
} }
func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// TODO(helin): wait before move on with next
// getTask call.
log.Errorln(err)
continue
}
for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
continue
}
s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() {
c.ch <- s.Record()
}
if s.Err() != nil {
log.Errorln(err, chunk.Path)
}
err = f.Close()
if err != nil {
log.Errorln(err)
}
}
// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
}
}
func (c *Client) monitorMaster(addr Addresser) { func (c *Client) monitorMaster(addr Addresser) {
lastMaster := "" lastMaster := ""
monitor := func() { monitor := func() {
...@@ -35,12 +82,12 @@ func (c *Client) monitorMaster(addr Addresser) { ...@@ -35,12 +82,12 @@ func (c *Client) monitorMaster(addr Addresser) {
if curMaster == "" { if curMaster == "" {
err := c.conn.Close() err := c.conn.Close()
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
} }
} else { } else {
err := c.conn.Connect(curMaster) err := c.conn.Connect(curMaster)
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order
...@@ -69,14 +116,22 @@ func (c *Client) SetDataset(globPaths []string) error { ...@@ -69,14 +116,22 @@ func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil) return c.conn.Call("Service.SetDataset", globPaths, nil)
} }
// GetTask gets a new task from the master server. // getTask gets a new task from the master server.
func (c *Client) GetTask() (Task, error) { func (c *Client) getTask() (Task, error) {
var t Task var t Task
err := c.conn.Call("Service.GetTask", 0, &t) err := c.conn.Call("Service.GetTask", 0, &t)
return t, err return t, err
} }
// TaskFinished tells the master server a task is finished. // TaskFinished tells the master server a task is finished.
func (c *Client) TaskFinished(taskID int) error { func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil) return c.conn.Call("Service.TaskFinished", taskID, nil)
} }
// NextRecord returns next record in the dataset.
//
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() []byte {
return <-c.ch
}
package master
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
const (
totalTask = 20
chunkPerTask = 10
)
func init() {
log.SetLevel(log.ErrorLevel)
}
type TestAddresser string
func (a TestAddresser) Address() string {
return string(a)
}
func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0"
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
go func(l net.Listener) {
s := NewService(chunkPerTask, time.Second, 1)
server := rpc.NewServer()
err := server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
f, err := os.Create(path)
if err != nil {
panic(err)
}
for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1)
w.Write(nil)
// call Close to force RecordIO writing a chunk.
w.Close()
}
f.Close()
// Manually intialize client to avoid calling c.getRecords()
c := &Client{}
c.conn = connection.New()
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
c.SetDataset([]string{path})
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
tasks = append(tasks, task)
}
_, err = c.getTask()
if err == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}
err = c.taskFinished(tasks[0].ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
}
tasks = append(tasks, task)
for _, task := range tasks {
err = c.taskFinished(task.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
}
}
for i := 0; i < 10; i++ {
checkOnePass(i)
}
}
...@@ -11,21 +11,15 @@ import ( ...@@ -11,21 +11,15 @@ import (
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
const ( func TestNextRecord(t *testing.T) {
totalTask = 20 const (
chunkPerTask = 10 path = "/tmp/master_client_TestFull"
) total = 50
)
var port int
func init() {
log.SetLevel(log.ErrorLevel)
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
...@@ -37,10 +31,9 @@ func init() { ...@@ -37,10 +31,9 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
port = p
go func(l net.Listener) { go func(l net.Listener) {
s := master.NewService(chunkPerTask, time.Second, 1) s := master.NewService(10, time.Second, 1)
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err := server.Register(s)
if err != nil { if err != nil {
...@@ -54,67 +47,33 @@ func init() { ...@@ -54,67 +47,33 @@ func init() {
panic(err) panic(err)
} }
}(l) }(l)
}
type addresser string f, err := os.Create(path)
func (a addresser) Address() string {
return string(a)
}
func TestClientFull(t *testing.T) {
const p = "/tmp/master_client_test_0"
f, err := os.Create(p)
if err != nil { if err != nil {
panic(err) panic(err)
} }
for i := 0; i < totalTask*chunkPerTask; i++ { w := recordio.NewWriter(f, -1, -1)
w := recordio.NewWriter(f, -1, -1) for i := 0; i < total; i++ {
w.Write(nil) w.Write([]byte{byte(i)})
// call Close to force RecordIO writing a chunk.
w.Close()
} }
w.Close()
f.Close() f.Close()
c := master.NewClient(addresser(fmt.Sprintf(":%d", port))) c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10)
c.SetDataset([]string{p}) c.SetDataset([]string{path})
checkOnePass := func(i int) { for pass := 0; pass < 50; pass++ {
var tasks []master.Task received := make(map[byte]bool)
for i := 0; i < totalTask; i++ { for i := 0; i < total; i++ {
task, err := c.GetTask() r := c.NextRecord()
if err != nil { if len(r) != 1 {
t.Fatal(i, err) t.Fatal("Length should be 1.", r)
} }
tasks = append(tasks, task) if received[r[0]] {
} t.Fatal("Received duplicate.", received, r)
_, err = c.GetTask()
if err == nil {
t.Fatal(i, "should get error.")
}
err = c.TaskFinished(tasks[0].ID)
if err != nil {
t.Fatal(err)
}
tasks = tasks[1:]
task, err := c.GetTask()
if err != nil {
t.Fatal(err)
}
tasks = append(tasks, task)
for _, task := range tasks {
err = c.TaskFinished(task.ID)
if err != nil {
t.Fatal(i, err)
} }
received[r[0]] = true
} }
} }
for i := 0; i < 10; i++ {
checkOnePass(i)
}
} }
...@@ -207,16 +207,26 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -207,16 +207,26 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
t.NumTimeout++ t.NumTimeout++
if t.NumTimeout > s.timeoutMax { if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v failed %d times, discard.\n", t.Task, t.NumTimeout) log.Warningf("Task %v timed out %d times, discard.\n", t.Task, t.NumTimeout)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return return
} }
log.Warningf("Task %v failed %d times, retry.\n", t.Task, t.NumTimeout) log.Warningf("Task %v timed out %d times, retry.\n", t.Task, t.NumTimeout)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.taskQueues.Todo = append(s.taskQueues.Todo, t)
} }
} }
// must be called with lock held.
func (s *Service) logFields() log.Fields {
return log.Fields{
"todoLen": len(s.taskQueues.Todo),
"pendingLen": len(s.taskQueues.Pending),
"doneLen": len(s.taskQueues.Done),
"failedLen": len(s.taskQueues.Failed),
}
}
// GetTask gets a new task from the service. // GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error { func (s *Service) GetTask(dummy int, task *Task) error {
select { select {
...@@ -230,7 +240,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -230,7 +240,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
if len(s.taskQueues.Done) == 0 { if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 { if len(s.taskQueues.Pending) == 0 {
err := errors.New("all task failed") err := errors.New("all task failed")
log.Warningln(err) log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err return err
} }
...@@ -243,12 +253,12 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -243,12 +253,12 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// in package. So we need to figure out a way // in package. So we need to figure out a way
// for client to check this error correctly. // for client to check this error correctly.
err := errors.New("no more available task") err := errors.New("no more available task")
log.Warningln(err) log.WithFields(s.logFields()).Warningln("No more available task.")
return err return err
} }
s.taskQueues.Todo = s.taskQueues.Done s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil s.taskQueues.Done = nil
log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.") log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
...@@ -261,7 +271,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -261,7 +271,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
} }
*task = t.Task *task = t.Task
log.Infof("Task #%d dispatched\n", task.ID) log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
return nil return nil
...@@ -276,12 +286,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -276,12 +286,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
log.Infof("Task %d finished\n", taskID)
t, ok := s.taskQueues.Pending[taskID] t, ok := s.taskQueues.Pending[taskID]
if !ok { if !ok {
err := errors.New("pending task not found") err := errors.New("pending task not found")
log.Warningln(err) log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err return err
} }
...@@ -290,8 +298,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -290,8 +298,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.taskQueues.Done = append(s.taskQueues.Done, t) s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.Infoln("No more todo and pending task, start a new pass.") log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil s.taskQueues.Done = nil
} }
......
package main package main
/* /*
#include <stdlib.h>
#include <string.h> #include <string.h>
typedef enum { typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0, PADDLE_ELEMENT_TYPE_INT32 = 0,
...@@ -26,12 +25,12 @@ typedef int paddle_pserver_client; ...@@ -26,12 +25,12 @@ typedef int paddle_pserver_client;
import "C" import "C"
import ( import (
"log"
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
) )
var nullPtr = unsafe.Pointer(uintptr(0)) var nullPtr = unsafe.Pointer(uintptr(0))
...@@ -134,10 +133,10 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, ...@@ -134,10 +133,10 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Printf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Println(err) log.Errorln(err)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -150,11 +149,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int { ...@@ -150,11 +149,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams() err := c.FinishInitParams()
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.") log.Warningln("parameters already initialized, treat paddle_finish_init_params as sucessful.")
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Println(err) log.Errorln(err)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -175,7 +174,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient ...@@ -175,7 +174,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient
c := get(client) c := get(client)
err := c.SendGrads(gs) err := c.SendGrads(gs)
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -192,7 +191,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -192,7 +191,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
c := get(client) c := get(client)
ps, err := c.GetParams(ns) ps, err := c.GetParams(ns)
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -201,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -201,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
...@@ -211,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -211,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }
...@@ -221,14 +220,14 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -221,14 +220,14 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nullPtr { if unsafe.Pointer(param) == nullPtr {
log.Println("must pre-allocate parameter.") log.Errorln("must pre-allocate parameter.")
return C.PSERVER_ERROR return C.PSERVER_ERROR
} else { }
if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) != len(p.Content) { if unsafe.Pointer(param.content) != nullPtr {
log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) if int(param.content_len) != len(p.Content) {
return C.PSERVER_ERROR log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
} return C.PSERVER_ERROR
} }
} }
...@@ -246,7 +245,7 @@ func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int { ...@@ -246,7 +245,7 @@ func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
c := get(client) c := get(client)
err := c.Save(p) err := c.Save(p)
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
......
#include <stdio.h> #include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h" #include "libpaddle_pserver_cclient.h"
......
...@@ -2,11 +2,11 @@ package pserver ...@@ -2,11 +2,11 @@ package pserver
import ( import (
"hash/fnv" "hash/fnv"
"log"
"sort" "sort"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
log "github.com/sirupsen/logrus"
) )
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
...@@ -64,7 +64,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -64,7 +64,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
if curServers[i].Addr == "" { if curServers[i].Addr == "" {
err := c.pservers[i].Close() err := c.pservers[i].Close()
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
} }
continue continue
...@@ -72,7 +72,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -72,7 +72,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
err := c.pservers[i].Connect(curServers[i].Addr) err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil { if err != nil {
log.Println(err) log.Errorln(err)
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order
......
...@@ -18,7 +18,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ...@@ -18,7 +18,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp
DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies}) DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies} paddle_master_shared)
add_custom_target(paddle_python ALL DEPENDS add_custom_target(paddle_python ALL DEPENDS
${OUTPUT_DIR}/.timestamp) ${OUTPUT_DIR}/.timestamp)
......
...@@ -26,6 +26,7 @@ import evaluator ...@@ -26,6 +26,7 @@ import evaluator
from . import dataset from . import dataset
from . import reader from . import reader
from . import plot from . import plot
from . import master
import attr import attr
import op import op
import pooling import pooling
...@@ -37,9 +38,26 @@ import plot ...@@ -37,9 +38,26 @@ import plot
import image import image
__all__ = [ __all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'optimizer',
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader', 'layer',
'topology', 'networks', 'infer', 'plot', 'evaluator', 'image' 'activation',
'parameters',
'init',
'trainer',
'event',
'data_type',
'attr',
'pooling',
'data_feeder',
'dataset',
'reader',
'topology',
'networks',
'infer',
'plot',
'evaluator',
'image',
'master',
] ]
......
from client import *
__all__ = ['client']
import ctypes
import os
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
lib = ctypes.cdll.LoadLibrary(path)
class client(object):
"""
client is a client to the master server.
"""
def __init__(self, addr, buf_size):
self.c = lib.paddle_new_master_client(addr, buf_size)
def close(self):
lib.paddle_release_master_client(self.c)
self.c = None
def set_dataset(self, paths):
holder_type = ctypes.c_char_p * len(paths)
holder = holder_type()
print paths
for idx, path in enumerate(paths):
c_ptr = ctypes.c_char_p(path)
holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths))
def next_record(self):
p = ctypes.c_char_p()
ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret)
if size == 0:
# Empty record
return ""
record = ret.contents.value[:size]
# Memory created from C should be freed.
lib.mem_free(ret.contents)
return record
from setuptools import setup from setuptools import setup
packages=['paddle', packages=['paddle',
'paddle.proto', 'paddle.proto',
'paddle.trainer', 'paddle.trainer',
...@@ -9,7 +8,8 @@ packages=['paddle', ...@@ -9,7 +8,8 @@ packages=['paddle',
'paddle.v2', 'paddle.v2',
'paddle.v2.dataset', 'paddle.v2.dataset',
'paddle.v2.reader', 'paddle.v2.reader',
'paddle.v2.plot'] 'paddle.v2.plot',
'paddle.v2.master']
setup_requires=["requests", setup_requires=["requests",
"numpy", "numpy",
...@@ -25,7 +25,8 @@ setup(name='paddle', ...@@ -25,7 +25,8 @@ setup(name='paddle',
description='Parallel Distributed Deep Learning', description='Parallel Distributed Deep Learning',
install_requires=setup_requires, install_requires=setup_requires,
packages=packages, packages=packages,
package_data={'paddle.v2.master': ['libpaddle_master.so'], },
package_dir={ package_dir={
'': '${CMAKE_CURRENT_SOURCE_DIR}' '': '${CMAKE_CURRENT_SOURCE_DIR}'
} },
) )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册