Commit 22076592 authored by: Y yangyaming

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into ssd_map

@@ -50,6 +50,7 @@ before_install:
   # protobuf version.
   - pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
   - pip install rarfile
+  - eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
   - |
     function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
 script:
......
@@ -127,6 +127,7 @@ endif(WITH_GPU)
 add_subdirectory(proto)
 add_subdirectory(paddle)
 add_subdirectory(python)
+add_subdirectory(go/pserver/cclient)
 if(WITH_DOC)
   add_subdirectory(doc)
......
@@ -74,14 +74,25 @@ typedef enum {
 typedef struct {
   char* name;
   paddle_element_type element_type;
-  void* content;
+  unsigned char* content;
   int content_len;
 } paddle_parameter, paddle_gradient;
 
-typedef struct paddle_pserver_client paddle_pserver_client;
+typedef int paddle_pserver_client;
 
-paddle_pserver_client* paddle_new_pserver_client();
-void paddle_pserver_client_release(paddle_pserver_client* client);
+/**
+ * @brief creates a pserver client that talks to etcd for coordination.
+ */
+paddle_pserver_client paddle_new_etcd_pserver_client(char* etcd_addr);
+
+/**
+ * @brief creates a pserver client given pserver addresses.
+ *
+ * @param pserver_addrs comma-separated pserver addresses.
+ * @param selected if current pserver client is selected to initialize all parameter servers.
+ */
+paddle_pserver_client paddle_new_pserver_client(char* pserver_addrs, int selected);
+void paddle_pserver_client_release(paddle_pserver_client c);
 
 /**
  * @brief paddle_begin_init_params begins to initialize parameters on
@@ -95,7 +106,7 @@ void paddle_pserver_client_release(paddle_pserver_client* client);
  * @return 1 if the trainer is selected to initialize parameter
  * servers, otherwise 0.
  */
-int paddle_begin_init_params(paddle_pserver_client* client);
+int paddle_begin_init_params(paddle_pserver_client client);
 
 /**
  * @brief paddle_init_param initializes the parameter on parameter
@@ -109,7 +120,7 @@ int paddle_begin_init_params(paddle_pserver_client* client);
  * @paddle_begin_init_param). Or simply exit the program and wait for
  * the cluster management system to restart the trainer.
  */
-int paddle_init_param(paddle_pserver_client* client, paddle_parameter param, const unsigned char* param_config_proto, int config_len);
+int paddle_init_param(paddle_pserver_client client, paddle_parameter param, const unsigned char* param_config_proto, int config_len);
 
 /**
  * @brief paddle_finish_init_params tells parameter servers client has
@@ -120,7 +131,7 @@ int paddle_init_param(paddle_pserver_client* client, paddle_parameter param, con
  * @paddle_begin_init_param). Or simply exit the program and wait for
  * the cluster management system to restart the trainer.
  */
-int paddle_finish_init_params(paddle_pserver_client* client);
+int paddle_finish_init_params(paddle_pserver_client client);
 
 /**
  * @brief paddle_send_grads sends gradients to parameter servers for
@@ -131,7 +142,7 @@ int paddle_finish_init_params(paddle_pserver_client* client);
  * @param learning_rate the learning rate for the gradients.
  * @return 0 if successful, otherwise -1.
  */
-int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grads, int len);
+int paddle_send_grads(paddle_pserver_client client, const paddle_gradient* grads, int len);
 
 /**
  * @brief paddle_get_params gets parameters from parameter servers.
@@ -139,13 +150,15 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad
  * paddle_get_params will block until parameters are initialized on
  * the parameter servers.
  *
- * @param names the array of names of the parameters to get.
- * @param dst the destination array of parameters to save to.
+ * @param dst the destination array of parameter pointers to save to.
+ * The parameter pointer must be pre-populated with the required parameter name,
+ * and the content of the parameter must be pre-allocated to the size of the
+ * required parameter on the pserver.
  * @param len the length of the names array and the paddle_parameter
  * array.
  * @return 0 if successful, otherwise -1.
  */
-int paddle_get_params(paddle_pserver_client* client, const char** names, paddle_parameter* dst, int len);
+int paddle_get_params(paddle_pserver_client client, paddle_parameter** dst, int len);
 
 /**
  * @brief paddle_save_model indicates parameters to save the parameter
@@ -154,5 +167,5 @@ int paddle_get_params(paddle_pserver_client* client, const char** names, paddle_
  * @param path the path to save parameters.
  * @return 0 if successful, otherwise -1.
  */
-int paddle_save_model(paddle_pserver_client* client, const char* path);
+int paddle_save_model(paddle_pserver_client client, const char* path);
 ```
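To make the revised client API concrete, here is a condensed usage sketch. It is an illustration only, modeled on the test program later in this change; the address, parameter name, and buffer sizes are made up. Note that `paddle_get_params` now requires each destination parameter to be pre-populated with its name and a pre-allocated content buffer:

```c
#include "libpaddle_pserver_cclient.h"

int main() {
  paddle_pserver_client c = paddle_new_pserver_client("localhost:3000", 1);

  if (paddle_begin_init_params(c)) {
    /* This trainer was selected: push the initial parameter values. */
    unsigned char content_a[2000] = {0};
    paddle_parameter param = {"param_a", PADDLE_ELEMENT_TYPE_FLOAT32,
                              content_a, 2000};
    if (paddle_init_param(c, param, NULL, 0) != 0) return -1; /* or retry */
    if (paddle_finish_init_params(c) != 0) return -1;         /* or retry */
  }

  /* Pre-populate the name and pre-allocate the content buffer, then pull
   * the (possibly updated) parameter back from the pservers. */
  unsigned char buf_a[2000] = {0};
  paddle_parameter dst = {"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, buf_a, 2000};
  paddle_parameter* dsts[1] = {&dst};
  if (paddle_get_params(c, dsts, 1) != 0) return -1;

  paddle_pserver_client_release(c);
  return 0;
}
```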
# Design Doc: Remote Parameter Updater for Cluster Train
For an overview of distributed training, please refer to the [distributed training design doc](README.md). In this design doc, we will discuss the parameter updater that uses the parameter server cclient ([The Client Library of Parameter Server Design Doc](pserver_client.md)) to manage and update parameters.
## Parameter Updater
The parameter updater is used by the trainer to manage and update parameters. There are mainly two kinds of parameter updaters: local and remote. Since this design is for cluster training, we will only discuss the remote parameter updater here.
### Remote Parameter Updater
The remote parameter updater manages parameters through the remote parameter server, using the client that communicates with the pserver ([The Client Library of Parameter Server Design Doc](pserver_client.md)).
In the PaddlePaddle Python V2 API, the trainer is implemented in Python; it holds an instance of the parameter updater and calls its functions directly. In this design, we will also expose the API of RemoteParameterUpdater to Python with SWIG.
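To make the interaction concrete before the interface is settled, the sketch below shows, in terms of the pserver cclient API, the per-minibatch call sequence the remote parameter updater would drive. The function name and parameters are illustrative assumptions, not a committed interface:

```c
#include "libpaddle_pserver_cclient.h"

/* Hypothetical per-minibatch flow of the remote parameter updater.
 * grads/params are the trainer's pre-allocated gradient and parameter
 * buffers; c is a client created by paddle_new_pserver_client(). */
int remote_update_one_batch(paddle_pserver_client c,
                            paddle_gradient** grads, int grad_num,
                            paddle_parameter** params, int param_num) {
  /* 1. Push the gradients computed by the trainer's backward pass. */
  if (paddle_send_grads(c, grads, grad_num) != 0) return -1;

  /* 2. Pull the updated parameters back into the pre-allocated buffers
   *    before the next forward pass. */
  if (paddle_get_params(c, params, param_num) != 0) return -1;

  return 0;
}
```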
#### Sparse Remote Parameter Updater
Since we will only implement dense parameter management for now, the mechanism for sparse parameters will be discussed in the next stage.
### Interface Design
TBD
@@ -17,7 +17,7 @@ function(GO_LIBRARY NAME BUILD_TYPE)
   endif()
 
   file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
-  file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
+  file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
 
   # find Paddle directory.
   get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
@@ -32,11 +32,13 @@ function(GO_LIBRARY NAME BUILD_TYPE)
   # will use the local changes in Paddle rather than checkout Paddle
   # in github.
   add_custom_target(copyPaddle
-    COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH})
+    COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
+    COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
   add_dependencies(goGet copyPaddle)
 
   add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
     COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
+    -gcflags=-shared -asmflags=-shared -installsuffix=_shared -a
     -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
     ${CMAKE_GO_FLAGS} ${GO_SOURCE}
     WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
......
 package main
 
 import (
-	"fmt"
 	"net"
 	"net/http"
 	"net/rpc"
-	"os"
-	"path/filepath"
 	"strconv"
-	"strings"
 	"time"
 
 	"github.com/namsral/flag"
 
 	"github.com/PaddlePaddle/Paddle/go/master"
-	"github.com/PaddlePaddle/recordio"
 )
 
 func main() {
 	port := flag.Int("port", 8080, "port of the master server.")
-	dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")
 	faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
 	taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
 	taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
 	chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
 	flag.Parse()
 
-	if *dataset == "" {
-		panic("no dataset specified.")
-	}
-
 	if *faultTolerance {
 		panic("fault tolernance not implemented.")
 	}
-
-	var chunks []master.Chunk
-	var paths []string
-	ss := strings.Split(*dataset, ",")
-
-	fmt.Println(ss)
-
-	for _, s := range ss {
-		match, err := filepath.Glob(s)
-		if err != nil {
-			panic(err)
-		}
-		paths = append(paths, match...)
-	}
-
-	if len(paths) == 0 {
-		panic("no valid datset specified.")
-	}
-
-	idx := 0
-	for _, path := range paths {
-		f, err := os.Open(path)
-		if err != nil {
-			panic(err)
-		}
-
-		index, err := recordio.LoadIndex(f)
-		if err != nil {
-			panic(err)
-		}
-		f.Close()
-
-		count := index.NumChunks()
-		for i := 0; i < count; i++ {
-			chunk := master.Chunk{
-				Idx:   idx,
-				Path:  path,
-				Index: *index.ChunkIndex(i),
-			}
-			chunks = append(chunks, chunk)
-		}
-	}
-
-	s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
+	s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
 	err := rpc.Register(s)
 	if err != nil {
 		panic(err)
......
@@ -2,6 +2,7 @@ package connection
 
 import (
 	"errors"
+	"log"
 	"net/rpc"
 	"sync"
 )
@@ -21,6 +22,18 @@ func New() *Conn {
 	return c
 }
 
+// Close closes the connection.
+func (c *Conn) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.client == nil {
+		return nil
+	}
+
+	return c.client.Close()
+}
+
 // Connect connects the connection to a address.
 func (c *Conn) Connect(addr string) error {
 	c.mu.Lock()
@@ -50,12 +63,20 @@ func (c *Conn) Connect(addr string) error {
 			c.waitConn = nil
 		}
 	} else {
+		err := client.Close()
+		if err != nil {
+			log.Println(err)
+		}
+
 		return errors.New("client already set from a concurrent goroutine")
 	}
 
 	return nil
 }
 
+// TODO(helin): refactor Call to be able to perform given retry
+// policy.
+
 // Call make a RPC call.
 //
 // Call will be blocked until the connection to remote RPC service
......
package master

import (
	"log"
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
)

// Addresser provides the address of the master server.
type Addresser interface {
	Address() string
}

// Client is the client of the master server.
type Client struct {
	conn *connection.Conn
}

// NewClient creates a new Client.
func NewClient(addr Addresser) *Client {
	c := &Client{}
	c.conn = connection.New()
	go c.monitorMaster(addr)
	return c
}

func (c *Client) monitorMaster(addr Addresser) {
	lastMaster := ""
	monitor := func() {
		// get the latest address of the master server,
		// connect to the new address once the address changes.
		curMaster := addr.Address()
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
				if err != nil {
					log.Println(err)
				}
			} else {
				err := c.conn.Connect(curMaster)
				if err != nil {
					log.Println(err)

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curMaster = lastMaster
				}
			}
		}

		lastMaster = curMaster
	}

	monitor()
	ticker := time.NewTicker(10 * time.Second)
	for range ticker.C {
		monitor()
	}
}

// SetDataset sets the dataset for the master server to dispatch.
//
// SetDataset can be called multiple times from different nodes. But
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error {
	return c.conn.Call("Service.SetDataset", globPaths, nil)
}

// GetTask gets a new task from the master server.
func (c *Client) GetTask() (Task, error) {
	var t Task
	err := c.conn.Call("Service.GetTask", 0, &t)
	return t, err
}

// TaskFinished tells the master server a task is finished.
func (c *Client) TaskFinished(taskID int) error {
	return c.conn.Call("Service.TaskFinished", taskID, nil)
}
package master_test

import (
	"fmt"
	"net"
	"net/http"
	"net/rpc"
	"os"
	"strconv"
	"strings"
	"testing"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/PaddlePaddle/recordio"
)

const (
	totalTask    = 20
	chunkPerTask = 10
)

var port int

func init() {
	log.SetLevel(log.ErrorLevel)

	l, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}

	ss := strings.Split(l.Addr().String(), ":")
	p, err := strconv.Atoi(ss[len(ss)-1])
	if err != nil {
		panic(err)
	}
	port = p

	go func(l net.Listener) {
		s := master.NewService(chunkPerTask, time.Second, 1)
		server := rpc.NewServer()
		err := server.Register(s)
		if err != nil {
			panic(err)
		}

		mux := http.NewServeMux()
		mux.Handle(rpc.DefaultRPCPath, server)
		err = http.Serve(l, mux)
		if err != nil {
			panic(err)
		}
	}(l)
}

type addresser string

func (a addresser) Address() string {
	return string(a)
}

func TestClientFull(t *testing.T) {
	const p = "/tmp/master_client_test_0"
	f, err := os.Create(p)
	if err != nil {
		panic(err)
	}

	for i := 0; i < totalTask*chunkPerTask; i++ {
		w := recordio.NewWriter(f, -1, -1)
		w.Write(nil)
		// call Close to force RecordIO writing a chunk.
		w.Close()
	}
	f.Close()

	c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
	c.SetDataset([]string{p})

	checkOnePass := func(i int) {
		var tasks []master.Task
		for i := 0; i < totalTask; i++ {
			task, err := c.GetTask()
			if err != nil {
				t.Fatal(i, err)
			}
			tasks = append(tasks, task)
		}

		_, err = c.GetTask()
		if err == nil {
			t.Fatal(i, "should get error.")
		}

		err = c.TaskFinished(tasks[0].ID)
		if err != nil {
			t.Fatal(err)
		}

		tasks = tasks[1:]
		task, err := c.GetTask()
		if err != nil {
			t.Fatal(err)
		}
		tasks = append(tasks, task)

		for _, task := range tasks {
			err = c.TaskFinished(task.ID)
			if err != nil {
				t.Fatal(i, err)
			}
		}
	}

	for i := 0; i < 10; i++ {
		checkOnePass(i)
	}
}
@@ -2,29 +2,25 @@ package master
 
 import (
 	"errors"
-	"log"
+	"os"
+	"path/filepath"
 	"sync"
 	"time"
 
-	"github.com/PaddlePaddle/recordio"
-)
-
-const (
-	targetTaskCount = 300
-)
-
-// errors
-var (
-	ErrNoMoreTask          = errors.New("no more task for current pass")
-	ErrPendingTaskNotFound = errors.New("pending task not found")
+	log "github.com/sirupsen/logrus"
+
+	"github.com/PaddlePaddle/recordio"
 )
 
 // Service is the master server service.
 type Service struct {
+	chunksPerTask int
 	timeoutDur time.Duration
 	timeoutMax int
+
+	ready chan struct{}
 
 	mu sync.Mutex
+	initDone bool
 	taskQueues taskQueues
 }
@@ -55,7 +51,6 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
 	if len(cur.Task.Chunks) > 0 {
 		cur.Task.ID = id
-		id++
 		result = append(result, cur)
 	}
@@ -63,21 +58,21 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
 }
 
 // NewService creates a new service.
-func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
+func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
 	s := &Service{}
+	s.chunksPerTask = chunksPerTask
 	s.timeoutDur = timeoutDur
 	s.timeoutMax = timeoutMax
 	s.taskQueues = taskQueues{}
 	s.taskQueues.Pending = make(map[int]taskEntry)
-	s.taskQueues.Todo = partition(chunks, chunksPerTask)
+	s.ready = make(chan struct{})
 	return s
 }
 
 // Chunk is a chunk of data consisted of several data instances.
 type Chunk struct {
-	Idx   int // index of the chunk within the file
 	Path  string
-	Index recordio.Index // block index
+	Index recordio.Index // chunk index
 }
 
 // Task is the basic unit of data instances assigned to trainers.
@@ -105,25 +100,87 @@ func (s *Service) snapshot() error {
 	return nil
 }
 
-// GetTask gets a new task from the service.
-func (s *Service) GetTask(dummy int, task *Task) error {
+func readChunks(globPaths []string) ([]Chunk, error) {
+	var chunks []Chunk
+	var paths []string
+
+	for _, s := range globPaths {
+		match, err := filepath.Glob(s)
+		if err != nil {
+			return nil, err
+		}
+		paths = append(paths, match...)
+	}
+
+	if len(paths) == 0 {
+		return nil, errors.New("no valid dataset specified")
+	}
+
+	for _, path := range paths {
+		f, err := os.Open(path)
+		if err != nil {
+			return nil, err
+		}
+
+		index, err := recordio.LoadIndex(f)
+		if err != nil {
+			return nil, err
+		}
+
+		err = f.Close()
+		if err != nil {
+			return nil, err
+		}
+
+		count := index.NumChunks()
+		for i := 0; i < count; i++ {
+			chunk := Chunk{
+				Path:  path,
+				Index: *index.ChunkIndex(i),
+			}
+			chunks = append(chunks, chunk)
+		}
+	}
+
+	return chunks, nil
+}
+
+// SetDataset sets the dataset to dispatch for the master server.
+//
+// SetDataset can be called multiple times. But only the first call
+// will be honored.
+func (s *Service) SetDataset(globPaths []string, dummy *int) error {
+	if len(globPaths) == 0 {
+		return errors.New("no dataset specified")
+	}
+
 	s.mu.Lock()
 	defer s.mu.Unlock()
+	if s.initDone {
+		// Already initialized. All trainer will call
+		// SetDataset, but we only handle the first one. Treat
+		// other calls as successful but do nothing.
+		return nil
+	}
 
-	if len(s.taskQueues.Todo) == 0 {
-		return ErrNoMoreTask
+	chunks, err := readChunks(globPaths)
+	if err != nil {
+		return err
 	}
 
-	t := s.taskQueues.Todo[0]
-	t.Epoch++
-	s.taskQueues.Todo = s.taskQueues.Todo[1:]
-	s.taskQueues.Pending[t.Task.ID] = t
-	err := s.snapshot()
+	s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
+
+	err = s.snapshot()
 	if err != nil {
+		log.Errorln(err)
 		return err
 	}
 
-	time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
+	close(s.ready)
+	s.initDone = true
+	return nil
+}
+
+func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
 	return func() {
 		s.mu.Lock()
 		defer s.mu.Unlock()
@@ -142,7 +199,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
 		defer func() {
 			err := s.snapshot()
 			if err != nil {
-				log.Println(err)
+				log.Errorln(err)
 			}
 		}()
@@ -150,29 +207,98 @@ func (s *Service) GetTask(dummy int, task *Task) error {
 		t.NumTimeout++
 		if t.NumTimeout > s.timeoutMax {
+			log.Warningf("Task %v failed %d times, discard.\n", t.Task, t.NumTimeout)
 			s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
 			return
 		}
 
+		log.Warningf("Task %v failed %d times, retry.\n", t.Task, t.NumTimeout)
 		s.taskQueues.Todo = append(s.taskQueues.Todo, t)
 	}
-	}(t.Task.ID, t.Epoch))
+}
+
+// GetTask gets a new task from the service.
+func (s *Service) GetTask(dummy int, task *Task) error {
+	select {
+	case <-s.ready:
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if len(s.taskQueues.Todo) == 0 {
+		if len(s.taskQueues.Done) == 0 {
+			if len(s.taskQueues.Pending) == 0 {
+				err := errors.New("all task failed")
+				log.Warningln(err)
+				return err
+			}
+
+			// TODO(helin): client need to retry in this
+			// error case. Gotcha: RPC client can't
+			// compare returned error with predefined
+			// errors like io.EOF, because the error
+			// instance deserialized from RPC is a
+			// different instance than the error defined
+			// in package. So we need to figure out a way
+			// for client to check this error correctly.
+			err := errors.New("no more available task")
+			log.Warningln(err)
+			return err
+		}
+		s.taskQueues.Todo = s.taskQueues.Done
+		s.taskQueues.Done = nil
+		log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
+	}
+
+	t := s.taskQueues.Todo[0]
+	t.Epoch++
+	s.taskQueues.Todo = s.taskQueues.Todo[1:]
+	s.taskQueues.Pending[t.Task.ID] = t
+	err := s.snapshot()
+	if err != nil {
+		return err
+	}
+
+	*task = t.Task
+	log.Infof("Task #%d dispatched\n", task.ID)
+
+	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
 	return nil
 }
 
 // TaskFinished tell the service that a task is finished.
 func (s *Service) TaskFinished(taskID int, dummy *int) error {
+	select {
+	case <-s.ready:
+	}
+
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
+	log.Infof("Task %d finished\n", taskID)
+
 	t, ok := s.taskQueues.Pending[taskID]
 	if !ok {
-		return ErrPendingTaskNotFound
+		err := errors.New("pending task not found")
+		log.Warningln(err)
+		return err
 	}
 
 	// task finished, reset timeout
 	t.NumTimeout = 0
 	s.taskQueues.Done = append(s.taskQueues.Done, t)
 	delete(s.taskQueues.Pending, taskID)
-	return s.snapshot()
+
+	if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
+		log.Infoln("No more todo and pending task, start a new pass.")
+		s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
+		s.taskQueues.Done = nil
+	}
+
+	err := s.snapshot()
+	if err != nil {
+		log.Errorln(err)
+	}
+	return err
 }
@@ -9,5 +9,15 @@ project(cxx_go C Go)
 include(golang)
 include(flags)
 
-go_library(client STATIC)
+go_library(paddle_pserver_cclient STATIC)
+
+if(PROJ_ROOT)
+  add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a
+    COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/libpaddle_pserver_cclient.h ${PROJ_ROOT}/paddle/trainer/
+    COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/libpaddle_pserver_cclient.a ${PROJ_ROOT}/paddle/trainer/
+    WORKING_DIRECTORY ${PROJ_ROOT}/paddle
+    DEPENDS paddle_pserver_cclient)
+  add_custom_target(paddle_pserver_cclient_lib ALL DEPENDS ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a)
+endif(PROJ_ROOT)
+
 add_subdirectory(test)
@@ -19,21 +19,9 @@ typedef struct {
   int content_len;
 } paddle_parameter, paddle_gradient;
 
-static inline void paddle_release_param(paddle_parameter* param) {
-  if (param != NULL) {
-    if (param->name != NULL) {
-      free(param->name);
-    }
-
-    if (param->content != NULL) {
-      free(param->content);
-    }
-
-    free(param);
-  }
-}
-
-typedef int client;
+typedef int paddle_pserver_client;
+#define PSERVER_ERROR -1
+#define PSERVER_OK 0
 */
 import "C"
@@ -48,10 +36,10 @@ import (
 var nullPtr = unsafe.Pointer(uintptr(0))
 var mu sync.Mutex
-var handleMap = make(map[C.client]*pserver.Client)
-var curHandle C.client
+var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
+var curHandle C.paddle_pserver_client
 
-func add(c *pserver.Client) C.client {
+func add(c *pserver.Client) C.paddle_pserver_client {
 	mu.Lock()
 	defer mu.Unlock()
 	client := curHandle
@@ -60,13 +48,13 @@ func add(c *pserver.Client) C.client {
 	return client
 }
 
-func get(client C.client) *pserver.Client {
+func get(client C.paddle_pserver_client) *pserver.Client {
 	mu.Lock()
 	defer mu.Unlock()
 	return handleMap[client]
 }
 
-func remove(client C.client) *pserver.Client {
+func remove(client C.paddle_pserver_client) *pserver.Client {
 	mu.Lock()
 	defer mu.Unlock()
 	h := handleMap[client]
@@ -100,7 +88,7 @@ func (l lister) List() []pserver.Server {
 }
 
 //export paddle_new_pserver_client
-func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
+func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
 	a := C.GoString(addrs)
 	as := strings.Split(a, ",")
 	servers := make([]pserver.Server, len(as))
@@ -113,27 +101,27 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
 }
 
 //export paddle_new_etcd_pserver_client
-func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.client {
+func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
 	// TODO(helin): fault tolerant pserver client using etcd.
 	panic("not implemented.")
 }
 
 //export paddle_pserver_client_release
-func paddle_pserver_client_release(client C.client) {
+func paddle_pserver_client_release(client C.paddle_pserver_client) {
 	remove(client)
 }
 
 //export paddle_begin_init_params
-func paddle_begin_init_params(client C.client) C.int {
+func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
 	c := get(client)
 	if selected := c.BeginInitParams(); selected {
 		return 1
 	}
-	return 0
+	return C.PSERVER_OK
 }
 
 //export paddle_init_param
-func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
+func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
 	et := pserver.ElementType(param.element_type)
 	name := C.GoString(param.name)
 	content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
@@ -143,31 +131,41 @@ func paddle_init_param(client C.client, param C.paddle_parameter, param_config u
 	}
 	c := get(client)
 	err := c.InitParam(pc)
 	if err != nil {
+		if err.Error() == pserver.AlreadyInitialized {
+			log.Printf("parameter %s already initialized, treat paddle_init_param as successful.\n", name)
+			return C.PSERVER_OK
+		}
 		log.Println(err)
-		return -1
+		return C.PSERVER_ERROR
 	}
 
-	return 0
+	return C.PSERVER_OK
 }
 
 //export paddle_finish_init_params
-func paddle_finish_init_params(client C.client) C.int {
+func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
 	c := get(client)
 	err := c.FinishInitParams()
 	if err != nil {
+		if err.Error() == pserver.AlreadyInitialized {
+			log.Println("parameters already initialized, treat paddle_finish_init_params as successful.")
+			return C.PSERVER_OK
+		}
+
 		log.Println(err)
-		return -1
+		return C.PSERVER_ERROR
 	}
 
-	return 0
+	return C.PSERVER_OK
 }
 
 //export paddle_send_grads
-func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
+func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient, total C.int) C.int {
 	var gs []pserver.Gradient
 	for i := 0; i < int(total); i++ {
-		grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
+		grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
 		et := pserver.ElementType(grad.element_type)
 		name := C.GoString(grad.name)
 		content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
@@ -178,83 +176,81 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C
 	err := c.SendGrads(gs)
 	if err != nil {
 		log.Println(err)
-		return -1
+		return C.PSERVER_ERROR
 	}
 
-	return 0
+	return C.PSERVER_OK
 }
 
 //export paddle_get_params
-func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int {
+func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
 	var ns []string
 	for i := 0; i < int(total); i++ {
-		name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
-		ns = append(ns, C.GoString(name))
+		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
+		ns = append(ns, C.GoString(param.name))
 	}
 	c := get(client)
 	ps, err := c.GetParams(ns)
 	if err != nil {
 		log.Println(err)
-		return -1
+		return C.PSERVER_ERROR
 	}
 
-	for i := 0; i < int(total); i++ {
-		if i >= len(ps) {
-			break
+	if len(ps) != len(ns) {
+		pn := make([]string, len(ps))
+		for i, p := range ps {
+			pn[i] = p.Name
+		}
+		log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
+		return C.PSERVER_ERROR
+	}
+
+	for i := range ps {
+		if ns[i] != ps[i].Name {
+			pn := make([]string, len(ps))
+			for i, p := range ps {
+				pn[i] = p.Name
+			}
+			log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
+			return C.PSERVER_ERROR
 		}
+	}
 
+	for i := 0; i < int(total); i++ {
 		p := ps[i]
 		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
-		nameReady := false
-		contentAllocated := false
 
 		if unsafe.Pointer(param) == nullPtr {
-			param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param))))
+			log.Println("must pre-allocate parameter.")
+			return C.PSERVER_ERROR
 		} else {
-			if unsafe.Pointer(param.name) != nullPtr {
-				if n := C.GoString(param.name); n != p.Name {
-					log.Println("Warning: the pre-allocated parameter name does not match the parameter name, it will be freed.", n, p.Name)
-					C.free(unsafe.Pointer(param.name))
-				} else {
-					nameReady = true
-				}
-			}
-
 			if unsafe.Pointer(param.content) != nullPtr {
-				if int(param.content_len) == len(p.Content) {
-					contentAllocated = true
-				} else {
-					log.Println("Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed.", param.content_len, len(p.Content))
-					C.free(unsafe.Pointer(param.content))
+				if int(param.content_len) != len(p.Content) {
+					log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
+					return C.PSERVER_ERROR
 				}
 			}
 		}
 
-		if !nameReady {
-			param.name = C.CString(p.Name)
-		}
-
-		if !contentAllocated {
-			param.content = (*C.uchar)(C.malloc(C.size_t(len(p.Content))))
-		}
-
 		C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
 		param.content_len = C.int(len(p.Content))
 		param.element_type = C.paddle_element_type(p.ElementType)
 	}
 
-	return 0
+	return C.PSERVER_OK
 }
 
 //export paddle_save_model
-func paddle_save_model(client C.client, path *C.char) C.int {
+func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
 	p := C.GoString(path)
 	c := get(client)
 	err := c.Save(p)
 	if err != nil {
 		log.Println(err)
-		return -1
+		return C.PSERVER_ERROR
 	}
 
-	return 0
+	return C.PSERVER_OK
 }
 
 func main() {} // Required but ignored
 cmake_minimum_required(VERSION 3.0)
 
-include_directories(${CMAKE_BINARY_DIR})
-
 add_executable(main main.c)
-add_dependencies(main client)
+add_dependencies(main paddle_pserver_cclient)
+add_executable(test_cclient test_cclient.c)
+add_dependencies(test_cclient paddle_pserver_cclient)
 
 if(APPLE)
   set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
+else()
+  set(CMAKE_EXE_LINKER_FLAGS "-pthread")
 endif()
 
-target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a)
+if(PROJ_ROOT)
+  include_directories(${CMAKE_CURRENT_BINARY_DIR}/..)
+  target_link_libraries(main ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
+  target_link_libraries(test_cclient ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
+else(PROJ_ROOT)
+  include_directories(${CMAKE_BINARY_DIR})
+  target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
+  target_link_libraries(test_cclient ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
+endif(PROJ_ROOT)
 #include <stdio.h>
 
-#include "libclient.h"
+#include "libpaddle_pserver_cclient.h"
 
-void fail() {
-  // TODO(helin): fix: gtest using cmake is not working, using this
-  // hacky way for now.
-  printf("test failed.\n");
-  exit(-1);
-}
+// TODO(helin): Fix: gtest using cmake is not working, using this
+// hacky way for now.
+#define fail()                                          \
+  fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \
+  exit(-1);
+
+void sendGrads(paddle_pserver_client c) {
+  unsigned char grad_a[2000] = {2};
+  unsigned char grad_b[3000] = {3};
+  paddle_gradient grad1 = {
+      "param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000};
+  paddle_gradient grad2 = {
+      "param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000};
+  paddle_gradient* grads[2] = {&grad1, &grad2};
+  if (paddle_send_grads(c, grads, 2)) {
+    fail();
+  }
+}
+
+void getParams(paddle_pserver_client c) {
+  paddle_parameter param_a;
+  paddle_parameter param_b;
+  char name_a[] = "param_a";
+  char name_b[] = "param_b";
+  // Must pre-allocate the parameter content before calling paddle_get_params.
+  unsigned char content_a[2000] = {};
+  unsigned char content_b[3000] = {};
+  param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
+  param_a.name = name_a;
+  param_a.content = content_a;
+  param_a.content_len = 2000;
+  param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
+  param_b.name = name_b;
+  param_b.content = content_b;
+  param_b.content_len = 3000;
+
+  paddle_parameter* params[2] = {&param_a, &param_b};
+  if (paddle_get_params(c, params, 2)) {
+    fail();
+  }
+}
 
 int main() {
   char addr[] = "localhost:3000";
-  client c = paddle_new_pserver_client(addr, 1);
+  paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
 retry:
   if (paddle_begin_init_params(c)) {
     paddle_parameter param;
     char name_a[] = "param_a";
     char name_b[] = "param_b";
-    unsigned char content[] = {0x00, 0x11, 0x22};
+    unsigned char content_a[2000] = {1};
+    unsigned char content_b[3000] = {0};
     param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
     param.name = name_a;
-    param.content = content;
-    param.content_len = 3;
-    if (paddle_init_param(c, param, NULL, 0) != 0) {
+    param.content = content_a;
+    param.content_len = 2000;
+    int error = paddle_init_param(c, param, NULL, 0);
+    if (error != 0) {
       goto retry;
     }
-    param.element_type = PADDLE_ELEMENT_TYPE_INT32;
+
+    param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
     param.name = name_b;
-    param.content = content;
-    param.content_len = 3;
-    if (paddle_init_param(c, param, NULL, 0) != 0) {
+    param.content = content_b;
+    param.content_len = 3000;
+    error = paddle_init_param(c, param, NULL, 0);
+    if (error != 0) {
       goto retry;
     }
-    if (paddle_finish_init_params(c) != 0) {
-      goto retry;
-    }
-  } else {
-    fail();
-  }
-
-  unsigned char content[] = {0x00, 0x11, 0x22};
-  paddle_gradient grads[2] = {
-      {"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3},
-      {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
 
-  if (!paddle_send_grads(c, grads, 2)) {
-    fail();
+    error = paddle_finish_init_params(c);
+    if (error != 0) {
+      goto retry;
+    }
   }
 
-  paddle_parameter* params[2] = {NULL, NULL};
-  char* names[] = {"param_a", "param_b"};
-  if (!paddle_get_params(c, names, params, 2)) {
-    fail();
+  int i;
+  for (i = 0; i < 100; i++) {
+    sendGrads(c);
+    getParams(c);
   }
 
-  // get parameters again by reusing the allocated parameter buffers.
-  if (!paddle_get_params(c, names, params, 2)) {
-    fail();
-  }
-
-  paddle_release_param(params[0]);
-  paddle_release_param(params[1]);
-
-  if (!paddle_save_model(c, "/tmp/")) {
+  if (paddle_save_model(c, "/tmp/")) {
     fail();
   }
......
#include <stdio.h>
#include <stdlib.h>

#include "libpaddle_pserver_cclient.h"

typedef float real;

void fail() {
  // TODO(helin): fix: gtest using cmake is not working, using this
  // hacky way for now.
  printf("test failed.\n");
  exit(-1);
}

void print_parameter(paddle_gradient* param) {
  if (param == NULL) {
    printf("param is NULL!!\n");
  } else {
    printf("==== parameter ====\n");
    printf("name: %s\n", param->name);
    printf("content_len: %d\n", param->content_len);
    printf("content_type: %d\n", param->element_type);
    int i;
    for (i = 0; i < param->content_len / (int)sizeof(real); ++i) {
      printf("%f ", ((float*)param->content)[i]);
    }
    printf("\n\n");
  }
}

int main() {
  char addr[] = "localhost:3000";
  paddle_pserver_client c = paddle_new_pserver_client(addr, 1);

  char* names[] = {"param_a", "param_b"};

retry:
  printf("init parameter to pserver:\n");

  real param_content1[] = {0.1, 0.2, 0.3};
  real param_content2[] = {0.4, 0.5, 0.6};
  paddle_parameter** params =
      (paddle_parameter**)malloc(sizeof(paddle_parameter*) * 2);
  params[0] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
  params[0]->name = names[0];
  params[0]->content = (unsigned char*)param_content1;
  params[0]->content_len = 3 * sizeof(real);
  params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;

  params[1] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
  params[1]->name = names[1];
  params[1]->content = (unsigned char*)param_content2;
  params[1]->content_len = 3 * sizeof(real);
  params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;

  if (paddle_begin_init_params(c)) {
    if (paddle_init_param(c, *params[0], NULL, 0) != 0) {
      goto retry;
    }
    if (paddle_init_param(c, *params[1], NULL, 0) != 0) {
      goto retry;
    }
    if (paddle_finish_init_params(c) != 0) {
      goto retry;
    }
  } else {
    fail();
  }

  printf("get inited parameters from pserver:\n");
  // get parameters again by reusing the allocated parameter buffers.
  if (paddle_get_params(c, params, 2) != 0) {
    fail();
  }
  print_parameter(params[0]);
  print_parameter(params[1]);

  printf("send gradient to pserver:\n");
  real gradient_content1[] = {0.01, 0.02, 0.03};
  real gradient_content2[] = {0.04, 0.05, 0.06};
  paddle_gradient** grads =
      (paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
  grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
  grads[0]->name = names[0];
  grads[0]->content = (unsigned char*)gradient_content1;
  grads[0]->content_len = 3 * sizeof(real);
  grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;

  grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
  grads[1]->name = names[1];
  grads[1]->content = (unsigned char*)gradient_content2;
  grads[1]->content_len = 3 * sizeof(real);
  grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;

  printf("print gradient sent to pserver:\n");
  print_parameter(grads[0]);
  print_parameter(grads[1]);

  if (paddle_send_grads(c, grads, 2) != 0) {
    fail();
  }

  printf("get updated parameters from pserver:\n");
  // get parameters again by reusing the allocated parameter buffers.
  if (paddle_get_params(c, params, 2) != 0) {
    fail();
  }
  print_parameter(params[0]);
  print_parameter(params[1]);

  if (paddle_save_model(c, "/tmp/") != 0) {
    fail();
  }

  return 0;
}
import paddle.v2 as paddle
import gzip


def softmax_regression(img):
    predict = paddle.layer.fc(input=img,
                              size=10,
                              act=paddle.activation.Softmax())
    return predict


def multilayer_perceptron(img):
    # The first fully-connected layer
    hidden1 = paddle.layer.fc(input=img, size=128, act=paddle.activation.Relu())
    # The second fully-connected layer and the according activation function
    hidden2 = paddle.layer.fc(input=hidden1,
                              size=64,
                              act=paddle.activation.Relu())
    # The third fully-connected layer, note that the hidden size should be 10,
    # which is the number of unique digits
    predict = paddle.layer.fc(input=hidden2,
                              size=10,
                              act=paddle.activation.Softmax())
    return predict


def convolutional_neural_network(img):
    # first conv layer
    conv_pool_1 = paddle.networks.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        num_channel=1,
        pool_size=2,
        pool_stride=2,
        act=paddle.activation.Tanh())
    # second conv layer
    conv_pool_2 = paddle.networks.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        num_channel=20,
        pool_size=2,
        pool_stride=2,
        act=paddle.activation.Tanh())
    # The first fully-connected layer
    fc1 = paddle.layer.fc(input=conv_pool_2,
                          size=128,
                          act=paddle.activation.Tanh())
    # The softmax layer, note that the hidden size should be 10,
    # which is the number of unique digits
    predict = paddle.layer.fc(input=fc1,
                              size=10,
                              act=paddle.activation.Softmax())
    return predict


def main():
    paddle.init(use_gpu=False, trainer_count=1)

    # define network topology
    images = paddle.layer.data(
        name='pixel', type=paddle.data_type.dense_vector(784))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))

    # Here we can build the prediction network in different ways. Please
    # choose one by uncommenting the corresponding line.
    predict = softmax_regression(images)
    #predict = multilayer_perceptron(images)
    #predict = convolutional_neural_network(images)

    cost = paddle.layer.classification_cost(input=predict, label=label)

    parameters = paddle.parameters.create(cost)

    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1 / 128.0,
        momentum=0.9,
        regularization=paddle.optimizer.L2Regularization(rate=0.0005 * 128))

    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=optimizer,
                                 is_local=False,
                                 pserver_spec="localhost:3000")

    lists = []

    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 1000 == 0:
                print "Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics)
        elif isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=paddle.batch(
                paddle.dataset.mnist.test(), batch_size=128))
            print "Test with Pass %d, Cost %f, %s\n" % (
                event.pass_id, result.cost, result.metrics)
            lists.append((event.pass_id, result.cost,
                          result.metrics['classification_error_evaluator']))

    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=8192),
            batch_size=128),
        event_handler=event_handler,
        num_passes=100)

    # find the best pass
    best = sorted(lists, key=lambda list: float(list[1]))[0]
    print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
    print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)

    test_creator = paddle.dataset.mnist.test()
    test_data = []
    for item in test_creator():
        test_data.append((item[0], ))
        if len(test_data) == 100:
            break

    # output is a softmax layer. It returns probabilities.
    # Shape should be (100, 10)
    probs = paddle.infer(
        output_layer=predict, parameters=parameters, input=test_data)
    print probs.shape


if __name__ == '__main__':
    main()
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing


def main():
    # init
    paddle.init(use_gpu=False, trainer_count=1)

    # network config
    x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
    y_predict = paddle.layer.fc(input=x,
                                param_attr=paddle.attr.Param(name='w'),
                                size=1,
                                act=paddle.activation.Linear(),
                                bias_attr=paddle.attr.Param(name='b'))
    y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
    cost = paddle.layer.mse_cost(input=y_predict, label=y)

    # create parameters
    parameters = paddle.parameters.create(cost)

    # create optimizer
    optimizer = paddle.optimizer.Momentum(momentum=0)

    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=optimizer,
                                 is_local=False,
                                 pserver_spec="localhost:3000")

    # event_handler to print training and testing info
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print "Pass %d, Batch %d, Cost %f" % (
                    event.pass_id, event.batch_id, event.cost)

        if isinstance(event, paddle.event.EndPass):
            if (event.pass_id + 1) % 10 == 0:
                result = trainer.test(
                    reader=paddle.batch(
                        uci_housing.test(), batch_size=2),
                    feeding={'x': 0,
                             'y': 1})
                print "Test %d, %.2f" % (event.pass_id, result.cost)

    # training
    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                uci_housing.train(), buf_size=500),
            batch_size=2),
        feeding={'x': 0,
                 'y': 1},
        event_handler=event_handler,
        num_passes=30)


if __name__ == '__main__':
    main()
@@ -6,7 +6,7 @@ import (
 	"sort"
 	"time"
 
-	"github.com/PaddlePaddle/Paddle/go/pserver/internal/connection"
+	"github.com/PaddlePaddle/Paddle/go/connection"
 )
 
 // TODO(helin): add RPC call retry logic
@@ -47,7 +47,7 @@ func NewClient(l Lister, pserverNum int, sel Selector) *Client {
 // monitorPservers monitors pserver addresses, and updates connection
 // when the address changes.
 func (c *Client) monitorPservers(l Lister, pserverNum int) {
-	knownServers := make([]Server, pserverNum)
+	lastServers := make([]Server, pserverNum)
 	ticker := time.NewTicker(10 * time.Second)
 	monitor := func() {
 		curServers := make([]Server, pserverNum)
@@ -56,8 +56,20 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
 			curServers[l.Index] = l
 		}
 
-		for i := range knownServers {
-			if knownServers[i].Addr != curServers[i].Addr {
+		for i := range lastServers {
+			if lastServers[i].Addr == curServers[i].Addr {
+				continue
+			}
+
+			if curServers[i].Addr == "" {
+				err := c.pservers[i].Close()
+				if err != nil {
+					log.Println(err)
+				}
+
+				continue
+			}
+
 			err := c.pservers[i].Connect(curServers[i].Addr)
 			if err != nil {
 				log.Println(err)
@@ -65,16 +77,16 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
 				// connect to addr failed, set
 				// to last known addr in order
 				// to retry next time.
-				curServers[i].Addr = knownServers[i].Addr
-			}
+				curServers[i].Addr = lastServers[i].Addr
 			}
 		}
 
-		knownServers = curServers
+		lastServers = curServers
 	}
 
 	monitor()
-	for _ = range ticker.C {
+	for range ticker.C {
 		monitor()
 	}
 }
@@ -93,16 +105,14 @@ func (c *Client) BeginInitParams() bool {
 
 // InitParam initializes the parameter on parameter servers.
 func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
-	var dummy int
-	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
+	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
 }
 
 // FinishInitParams tells parameter servers client has sent all
 // parameters to parameter servers as initialization.
 func (c *Client) FinishInitParams() error {
 	for _, p := range c.pservers {
-		var dummy int
-		err := p.Call("Service.FinishInitParams", dummy, &dummy)
+		err := p.Call("Service.FinishInitParams", 0, nil)
 		if err != nil {
 			return err
 		}
@@ -116,8 +126,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
 	errCh := make(chan error, len(grads))
 	for _, g := range grads {
 		go func(g Gradient) {
-			var dummy int
-			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
+			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
 			errCh <- err
 		}(g)
 	}
@@ -196,8 +205,7 @@ func (c *Client) Save(path string) error {
 	errCh := make(chan error, len(c.pservers))
 	for _, p := range c.pservers {
-		var dummy int
-		err := p.Call("Service.Save", path, &dummy)
+		err := p.Call("Service.Save", path, nil)
 		errCh <- err
 	}
......
@@ -117,7 +117,7 @@ func TestClientFull(t *testing.T) {
 	for i := range params {
 		if names[i] != params[i].Name {
-			t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i])
+			t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
 		}
 	}
 }
@@ -32,7 +32,13 @@ int update_SGD(void* optimizer,
                const void* gradient,
                int num_bytes) {
   SGD_optimizer* o = (SGD_optimizer*)optimizer;
-  // TODO
+  float* parameter = (float*)buffer;
+  float* grad = (float*)gradient;
+
+  int i;
+  for (i = 0; i < num_bytes / sizeof(float); ++i) {
+    parameter[i] -= o->learning_rate * grad[i];
+  }
   return 0;
 }
......
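The loop added to `update_SGD` above applies the plain stochastic gradient descent rule elementwise over the float buffer:

$$\theta_i \leftarrow \theta_i - \eta \, g_i$$

where $\theta$ is the parameter buffer, $g$ is the gradient, and $\eta$ is `o->learning_rate` (set to 0.005 for the pserver service in the next file).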
...@@ -9,8 +9,10 @@ import ( ...@@ -9,8 +9,10 @@ import (
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
var ErrAlreadyInitialized = errors.New("pserver already initialized") const (
var ErrUninitialized = errors.New("pserver not fully initialized") AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized"
)
// Supported element types // Supported element types
const ( const (
...@@ -49,7 +51,7 @@ type Service struct { ...@@ -49,7 +51,7 @@ type Service struct {
// NewService creates a new service. // NewService creates a new service.
func NewService() *Service { func NewService() *Service {
s := &Service{opt: newOptimizer(sgd, 0.01)} s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
return s return s
...@@ -59,7 +61,7 @@ func NewService() *Service { ...@@ -59,7 +61,7 @@ func NewService() *Service {
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return ErrAlreadyInitialized return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -80,7 +82,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -80,7 +82,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return ErrAlreadyInitialized return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -94,7 +96,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error { ...@@ -94,7 +96,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
default: default:
return ErrUninitialized return errors.New(Uninitialized)
} }
s.mu.Lock() s.mu.Lock()
......
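The switch from sentinel error variables to exported string constants is deliberate: net/rpc transmits a remote error as its message string and rebuilds a fresh error value on the client, so an identity check like err == pserver.ErrAlreadyInitialized can never succeed across the wire, while comparing err.Error() against a shared constant does. A small demonstration:

package main

import (
	"errors"
	"fmt"
)

// Mirrors the constant introduced in the diff above.
const AlreadyInitialized = "pserver already initialized"

func main() {
	// net/rpc hands the client a freshly constructed error whose
	// message is the server's error text; the original value is lost.
	remoteErr := errors.New(AlreadyInitialized)
	sentinel := errors.New(AlreadyInitialized)

	fmt.Println(remoteErr == sentinel)                   // false: distinct values
	fmt.Println(remoteErr.Error() == AlreadyInitialized) // true: compare the text
}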
...@@ -15,8 +15,7 @@ func TestFull(t *testing.T) { ...@@ -15,8 +15,7 @@ func TestFull(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
var dummy int err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -25,12 +24,12 @@ func TestFull(t *testing.T) { ...@@ -25,12 +24,12 @@ func TestFull(t *testing.T) {
p1.Name = "param_b" p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
err = s.FinishInitParams(0, &dummy) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -46,11 +45,11 @@ func TestFull(t *testing.T) { ...@@ -46,11 +45,11 @@ func TestFull(t *testing.T) {
} }
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, &dummy) err = s.SendGrad(g1, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
err = s.SendGrad(g2, &dummy) err = s.SendGrad(g2, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
...@@ -74,23 +73,21 @@ func TestFull(t *testing.T) { ...@@ -74,23 +73,21 @@ func TestFull(t *testing.T) {
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var dummy int err := s.FinishInitParams(0, nil)
err := s.FinishInitParams(0, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
err = s.FinishInitParams(0, &dummy) err = s.FinishInitParams(0, nil)
if err != pserver.ErrAlreadyInitialized { if err.Error() != pserver.AlreadyInitialized {
t.FailNow() t.FailNow()
} }
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var dummy int err := s.SendGrad(pserver.Gradient{}, nil)
err := s.SendGrad(pserver.Gradient{}, &dummy) if err.Error() != pserver.Uninitialized {
if err != pserver.ErrUninitialized {
t.FailNow() t.FailNow()
} }
} }
...@@ -98,13 +95,14 @@ func TestUninitialized(t *testing.T) { ...@@ -98,13 +95,14 @@ func TestUninitialized(t *testing.T) {
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
errCh := make(chan error, 2)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
var param pserver.Parameter var param pserver.Parameter
err := s.GetParam("param_a", &param) err := s.GetParam("param_a", &param)
if err != nil { if err != nil {
t.FailNow() errCh <- err
} }
wg.Done() wg.Done()
ch <- struct{}{} ch <- struct{}{}
...@@ -112,10 +110,9 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -112,10 +110,9 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
var dummy int err := s.Save("", nil)
err := s.Save("", &dummy)
if err != nil { if err != nil {
t.FailNow() errCh <- err
} }
wg.Done() wg.Done()
ch <- struct{}{} ch <- struct{}{}
...@@ -127,6 +124,8 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -127,6 +124,8 @@ func TestBlockUntilInitialized(t *testing.T) {
case <-ch: case <-ch:
// some function returned before initialization is completed. // some function returned before initialization is completed.
t.FailNow() t.FailNow()
case <-errCh:
t.FailNow()
default: default:
} }
...@@ -134,13 +133,12 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -134,13 +133,12 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
var dummy int err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
err = s.FinishInitParams(0, &dummy) err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
......
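Two details of the test changes above are worth spelling out. First, goroutine failures are now reported through errCh because t.FailNow may only be called from the goroutine running the test. Second, the blocking behaviour under test rests on a common Go idiom: a channel that is never sent on, only closed, so every receive blocks until initialization completes. A minimal sketch of that idiom, with hypothetical names:

package main

import (
	"fmt"
	"time"
)

type service struct {
	initialized chan struct{} // closed exactly once, when init finishes
}

// getParam blocks until the service has been initialized.
func (s *service) getParam() string {
	<-s.initialized // unblocks as soon as the channel is closed
	return "param_a"
}

func main() {
	s := &service{initialized: make(chan struct{})}
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(s.initialized) // FinishInitParams does the equivalent
	}()
	fmt.Println(s.getParam()) // prints only after the close
}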
...@@ -16,7 +16,7 @@ set(API_HEADER ...@@ -16,7 +16,7 @@ set(API_HEADER
Internal.h) Internal.h)
add_library(paddle_api STATIC ${API_SOURCES}) add_library(paddle_api STATIC ${API_SOURCES})
add_dependencies(paddle_api gen_proto_cpp) add_dependencies(paddle_api gen_proto_cpp paddle_pserver_cclient_lib)
INCLUDE(${SWIG_USE_FILE}) INCLUDE(${SWIG_USE_FILE})
INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle) INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle)
...@@ -45,7 +45,7 @@ SET(SWIG_MODULE_swig_paddle_EXTRA_DEPS ...@@ -45,7 +45,7 @@ SET(SWIG_MODULE_swig_paddle_EXTRA_DEPS
) )
IF(APPLE) IF(APPLE)
SET(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load") SET(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load -framework CoreFoundation -framework Security")
ELSE(APPLE) ELSE(APPLE)
SET(START_GROUP "-Xlinker -start-group") SET(START_GROUP "-Xlinker -start-group")
SET(END_GROUP "-Xlinker -end-group") SET(END_GROUP "-Xlinker -end-group")
......
...@@ -179,6 +179,7 @@ namespace std { ...@@ -179,6 +179,7 @@ namespace std {
%newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater; %newobject ParameterUpdater::createLocalUpdater;
%newobject ParameterUpdater::createRemoteUpdater; %newobject ParameterUpdater::createRemoteUpdater;
%newobject ParameterUpdater::createNewRemoteUpdater;
%feature("director") UpdateCallback; %feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide %feature("autodoc", 1); // To generate method stub, for code hint in ide
......
...@@ -841,6 +841,8 @@ public: ...@@ -841,6 +841,8 @@ public:
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount, int passCount,
bool useSparseUpdater); bool useSparseUpdater);
static ParameterUpdater* createNewRemoteUpdater(
OptimizationConfig* config, const std::string pserverSpec);
~ParameterUpdater(); ~ParameterUpdater();
/** /**
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h" #include "PaddleAPIPrivate.h"
#include "paddle/trainer/NewRemoteParameterUpdater.h"
#include "paddle/trainer/RemoteParameterUpdater.h" #include "paddle/trainer/RemoteParameterUpdater.h"
#include "paddle/trainer/ThreadParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h"
...@@ -28,6 +29,14 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( ...@@ -28,6 +29,14 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
return updater; return updater;
} }
ParameterUpdater *ParameterUpdater::createNewRemoteUpdater(
OptimizationConfig *config, const std::string pserverSpec) {
auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::NewRemoteParameterUpdater(
config->m->getConfig(), pserverSpec));
return updater;
}
ParameterUpdater *ParameterUpdater::createRemoteUpdater( ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount, bool useSparseUpdater) { OptimizationConfig *config, int passCount, bool useSparseUpdater) {
auto updater = new ParameterUpdater(); auto updater = new ParameterUpdater();
......
This diff is collapsed.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <float.h>
#include <algorithm>
#include <vector>
#include "paddle/math/Matrix.h"
using std::vector;
using std::pair;
using std::map;
namespace paddle {
template <typename T>
struct BBoxBase {
BBoxBase(T xMin, T yMin, T xMax, T yMax)
: xMin(xMin), yMin(yMin), xMax(xMax), yMax(yMax), isDifficult(false) {}
BBoxBase() {}
T getWidth() const { return xMax - xMin; }
T getHeight() const { return yMax - yMin; }
T getCenterX() const { return (xMin + xMax) / 2; }
T getCenterY() const { return (yMin + yMax) / 2; }
T getArea() const { return getWidth() * getHeight(); }
// coordinate of bounding box
T xMin;
T yMin;
T xMax;
T yMax;
// whether this is a difficult object (e.g. an object with heavy occlusion)
bool isDifficult;
};
struct NormalizedBBox : BBoxBase<real> {
NormalizedBBox() : BBoxBase<real>() {}
};
enum PermMode { kNCHWToNHWC, kNHWCToNCHW };
/**
* @brief First permute the input matrix, then append it to the output matrix
*/
size_t appendWithPermute(const Matrix& inMatrix,
size_t height,
size_t width,
size_t outTotalSize,
size_t outOffset,
size_t batchSize,
Matrix& outMatrix,
PermMode permMode);
/**
* @brief First permute the input matrix, then decompose it into the output
*/
size_t decomposeWithPermute(const Matrix& inMatrix,
size_t height,
size_t width,
size_t totalSize,
size_t offset,
size_t batchSize,
Matrix& outMatrix,
PermMode permMode);
/**
* @brief Compute jaccard overlap between two bboxes.
* @param bbox1 The first bbox
* @param bbox2 The second bbox
*/
real jaccardOverlap(const NormalizedBBox& bbox1, const NormalizedBBox& bbox2);
/**
* @brief Compute offset parameters between a prior bbox and a ground-truth
* bbox; the variances of the prior bbox are taken into account
* @param priorBBox Input prior bbox
* @param priorBBoxVar Variance parameters of prior bbox
* @param gtBBox Groundtruth bbox
* @param outVec Output vector
*/
void encodeBBoxWithVar(const NormalizedBBox& priorBBox,
const vector<real>& priorBBoxVar,
const NormalizedBBox& gtBBox,
vector<real>& outVec);
/**
* @brief Decode a prior bbox with offset parameters;
* the variances of the prior bbox are taken into account
* @param priorBBox Prior bbox to be decoded
* @param priorBBoxVar Variance parameters of prior bbox
* @param locPredData Offset parameters
*/
NormalizedBBox decodeBBoxWithVar(const NormalizedBBox& priorBBox,
const vector<real>& priorBBoxVar,
const vector<real>& locPredData);
/**
* @brief Extract bboxes from prior matrix, the layout is
* xmin1 | ymin1 | xmax1 | ymax1 | xmin1Var | ymin1Var | xmax1Var | ymax1Var ...
* @param priorData Matrix of prior value
* @param numBBoxes Number of bbox to be extracted
* @param bboxVec Append to the vector
*/
void getBBoxFromPriorData(const real* priorData,
const size_t numBBoxes,
vector<NormalizedBBox>& bboxVec);
/**
* @brief Extract labels, scores and bboxes from detection matrix, the layout is
* imageId | label | score | xmin | ymin | xmax | ymax
* @param detectData Matrix of detection value
* @param numBBoxes Number of bbox to be extracted
* @param labelVec Label of bbox
* @param scoreVec Score of bbox
* @param bboxVec Append to the vector
*/
void getBBoxFromDetectData(const real* detectData,
const size_t numBBoxes,
vector<real>& labelVec,
vector<real>& scoreVec,
vector<NormalizedBBox>& bboxVec);
/**
* @brief Extract variances from prior matrix, the layout is
* xmin1 | ymin1 | xmax1 | ymax1 | xmin1Var | ymin1Var | xmax1Var | ymax1Var ...
* @param priorData Matrix of prior value
* @param num Number to be extracted
* @param varVec Append to the vector
*/
void getBBoxVarFromPriorData(const real* priorData,
const size_t num,
vector<vector<real>>& varVec);
/**
* @brief Extract bboxes from label matrix, the layout is
* class1_1 | xmin1_1 | ymin1_1 | xmax1_1 | ymax1_1 | difficult1_1 | ...
* @param labelData Matrix of label value
* @param numBBoxes Number to be extracted
* @param bboxVec Append to the vector
*/
void getBBoxFromLabelData(const real* labelData,
const size_t numBBoxes,
vector<NormalizedBBox>& bboxVec);
/**
* @brief Match prior bboxes to ground-truth bboxes; the strategy is:
1. Find the most overlapped bbox pair (prior and ground truth)
2. For the remaining prior bboxes, find the most overlapped ground-truth bbox
* @param priorBBoxes prior bbox
* @param gtBBoxes groundtruth bbox
* @param overlapThreshold Lower bound of overlap (used to judge whether matched)
* @param matchIndices For each prior bbox, groundtruth bbox index if matched
otherwise -1
* @param matchOverlaps For each prior bbox, overlap with all ground-truth bboxes
*/
void matchBBox(const vector<NormalizedBBox>& priorBBoxes,
const vector<NormalizedBBox>& gtBBoxes,
real overlapThreshold,
vector<int>* matchIndices,
vector<real>* matchOverlaps);
/**
* @brief Generate positive bboxes and negative bboxes,
|positive bboxes|/|negative bboxes| is negPosRatio
* @param priorValue Prior value
* @param numPriorBBoxes Number of prior bbox
* @param gtValue Groundtruth value
* @param gtStartPosPtr Since the ground-truth value is stored as a sequence type,
this parameter indicates the start position of each record
* @param seqNum Number of sequence
* @param maxConfScore Classification score for prior bbox, used to mine
negative examples
* @param batchSize Image number
* @param overlapThreshold Lower bound of overlap
* @param negOverlapThreshold Upper bound of overlap (to judge a negative example)
* @param negPosRatio Control number of negative bboxes
* @param matchIndicesVecPtr Save indices of matched prior bbox
* @param negIndicesVecPtr Save indices of negative prior bbox
*/
pair<size_t, size_t> generateMatchIndices(
const Matrix& priorValue,
const size_t numPriorBBoxes,
const Matrix& gtValue,
const int* gtStartPosPtr,
const size_t seqNum,
const vector<vector<real>>& maxConfScore,
const size_t batchSize,
const real overlapThreshold,
const real negOverlapThreshold,
const size_t negPosRatio,
vector<vector<int>>* matchIndicesVecPtr,
vector<vector<int>>* negIndicesVecPtr);
/**
* @brief Get max confidence score for each prior bbox
* @param confData Confidence scores, layout is
* class1 score | class2 score | ... | classN score ...
* @param batchSize Image number
* @param numPriorBBoxes Prior bbox number
* @param numClasses Classes number
* @param backgroundId Background id
* @param maxConfScoreVecPtr Output
*/
void getMaxConfidenceScores(const real* confData,
const size_t batchSize,
const size_t numPriorBBoxes,
const size_t numClasses,
const size_t backgroundId,
vector<vector<real>>* maxConfScoreVecPtr);
template <typename T>
bool sortScorePairDescend(const pair<real, T>& pair1,
const pair<real, T>& pair2);
template <>
bool sortScorePairDescend(const pair<real, NormalizedBBox>& pair1,
const pair<real, NormalizedBBox>& pair2);
/**
* @brief Do NMS for bboxes to remove duplicated bboxes
* @param bboxes BBoxes to apply NMS
* @param confScoreData Confidence scores
* @param classIdx Class to do NMS
* @param topK Number to keep
* @param confThreshold Low boundary of confidence score
* @param nmsThreshold Threshold of overlap
* @param numPriorBBoxes Total number of prior bboxes
* @param numClasses Total class number
* @param indices Indices of high quality bboxes
*/
void applyNMSFast(const vector<NormalizedBBox>& bboxes,
const real* confScoreData,
size_t classIdx,
size_t topK,
real confThreshold,
real nmsThreshold,
size_t numPriorBBoxes,
size_t numClasses,
vector<size_t>* indices);
/**
* @brief Get detection results which satisfy the requirements
* @param numPriorBBoxes Prior bbox number
* @param numClasses Class number
* @param backgroundId Background class
* @param batchSize Image number
* @param confThreshold Threshold of class confidence
* @param nmsTopK Used in NMS operation to keep top k bbox
* @param nmsThreshold Used in NMS, threshold of overlap
* @param keepTopK How many bboxes are kept per image
* @param allDecodedBBoxes Decoded bboxes for all images
* @param allDetectionIndices Save detection bbox indices
*/
size_t getDetectionIndices(
const real* confData,
const size_t numPriorBBoxes,
const size_t numClasses,
const size_t backgroundId,
const size_t batchSize,
const size_t confThreshold,
const size_t nmsTopK,
const real nmsThreshold,
const size_t keepTopK,
const vector<vector<NormalizedBBox>>& allDecodedBBoxes,
vector<map<size_t, vector<size_t>>>* allDetectionIndices);
/**
* @brief Get detection results
* @param confData Confidence scores
* @param numPriorBBoxes Prior bbox number
* @param numClasses Class number
* @param batchSize Image number
* @param allIndices Indices of predicted bboxes
* @param allDecodedBBoxes BBoxes decoded
* @param out Output matrix
* image number | label | confidence score | xMin | yMin | xMax | yMax
*/
void getDetectionOutput(const real* confData,
const size_t numKept,
const size_t numPriorBBoxes,
const size_t numClasses,
const size_t batchSize,
const vector<map<size_t, vector<size_t>>>& allIndices,
const vector<vector<NormalizedBBox>>& allDecodedBBoxes,
Matrix& out);
NormalizedBBox clipBBox(const NormalizedBBox& bbox);
} // namespace paddle
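For reference, jaccardOverlap above is the standard intersection-over-union score, area(A∩B) / area(A∪B), for axis-aligned boxes. A sketch of the computation in Go, where the struct and helper names are illustrative only:

package main

import "fmt"

type bbox struct{ xMin, yMin, xMax, yMax float64 }

func area(b bbox) float64 { return (b.xMax - b.xMin) * (b.yMax - b.yMin) }

func maxf(a, b float64) float64 {
	if a > b {
		return a
	}
	return b
}

func minf(a, b float64) float64 {
	if a < b {
		return a
	}
	return b
}

// jaccardOverlap returns area(a∩b)/area(a∪b), or 0 for disjoint boxes.
func jaccardOverlap(a, b bbox) float64 {
	ixMin, iyMin := maxf(a.xMin, b.xMin), maxf(a.yMin, b.yMin)
	ixMax, iyMax := minf(a.xMax, b.xMax), minf(a.yMax, b.yMax)
	if ixMin >= ixMax || iyMin >= iyMax {
		return 0
	}
	inter := (ixMax - ixMin) * (iyMax - iyMin)
	return inter / (area(a) + area(b) - inter)
}

func main() {
	// Two overlapping squares: intersection 1, union 7 -> 1/7.
	fmt.Println(jaccardOverlap(bbox{0, 0, 2, 2}, bbox{1, 1, 3, 3}))
}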
...@@ -4,6 +4,7 @@ set(TRAINER_SOURCES ...@@ -4,6 +4,7 @@ set(TRAINER_SOURCES
ParameterUpdater.cpp ParameterUpdater.cpp
ParamUtil.cpp ParamUtil.cpp
RemoteParameterUpdater.cpp RemoteParameterUpdater.cpp
NewRemoteParameterUpdater.cpp
Tester.cpp Tester.cpp
Trainer.cpp Trainer.cpp
TrainerInternal.cpp TrainerInternal.cpp
...@@ -16,6 +17,7 @@ set(TRAINER_HEADERS ...@@ -16,6 +17,7 @@ set(TRAINER_HEADERS
ParameterUpdater.h ParameterUpdater.h
ParamUtil.h ParamUtil.h
RemoteParameterUpdater.h RemoteParameterUpdater.h
NewRemoteParameterUpdater.h
Tester.h Tester.h
TesterConfig.h TesterConfig.h
Trainer.h Trainer.h
...@@ -32,7 +34,7 @@ add_style_check_target(paddle_trainer_lib ...@@ -32,7 +34,7 @@ add_style_check_target(paddle_trainer_lib
add_style_check_target(paddle_trainer_lib add_style_check_target(paddle_trainer_lib
${TRAINER_HEADERS}) ${TRAINER_HEADERS})
add_dependencies(paddle_trainer_lib add_dependencies(paddle_trainer_lib
gen_proto_cpp) gen_proto_cpp paddle_pserver_cclient_lib)
macro(add_paddle_exe TARGET_NAME) macro(add_paddle_exe TARGET_NAME)
add_executable(${TARGET_NAME} ${ARGN}) add_executable(${TARGET_NAME} ${ARGN})
...@@ -56,3 +58,10 @@ install(TARGETS paddle_trainer paddle_merge_model ...@@ -56,3 +58,10 @@ install(TARGETS paddle_trainer paddle_merge_model
set_target_properties(paddle_trainer PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) set_target_properties(paddle_trainer PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE)
set_target_properties(paddle_merge_model PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) set_target_properties(paddle_merge_model PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE)
if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
endif()
target_link_libraries(paddle_trainer ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a)
target_link_libraries(paddle_trainer_lib ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "NewRemoteParameterUpdater.h"
#include "Trainer.h"
#include "paddle/utils/Stat.h"
DECLARE_int32(trainer_id);
DECLARE_string(save_dir);
namespace paddle {
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
const OptimizationConfig &config, const std::string pserverSpec)
: parameterClient_(-1),
newParameters_(nullptr),
newGradients_(nullptr),
pserverSpec_(pserverSpec) {}
void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) {
ParameterUpdater::init(parameters);
for (auto &para : parameters_) {
para->getBuf(PARAMETER_VALUE)->zeroMem();
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
// create parameter server client.
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0);
// init new parameter and gradient.
newParameters_ = initNewParameter(PARAMETER_VALUE);
newGradients_ = initNewParameter(PARAMETER_GRADIENT);
// init parameters: one trainer gets the opportunity to initialize the
// parameters and send them to the parameter servers; the other trainers
// fetch the initialized parameters from the parameter servers.
if (paddle_begin_init_params(parameterClient_)) {
LOG(INFO) << "paddle_begin_init_params start";
for (int i = 0; i < parameterSize(); ++i) {
auto paramConfig = parameters_[i]->getConfig();
std::string bytes = paramConfig.SerializeAsString();
const char *array = bytes.data();
int size = (int)bytes.size();
paddle_init_param(
parameterClient_, *newParameters_[i], (void *)array, size);
}
paddle_finish_init_params(parameterClient_);
LOG(INFO) << "paddle_begin_init_params done";
} else {
paddle_get_params(parameterClient_, newParameters_, parameterSize());
}
LOG(INFO) << "NewRemoteParameterUpdater initialized";
}
void NewRemoteParameterUpdater::updateImpl(Parameter *para) {}
void NewRemoteParameterUpdater::finishBatch(real cost) {
// send gradient to parameter server.
paddle_send_grads(parameterClient_, newGradients_, parameterSize());
// get the updated parameter from parameterClient.
paddle_get_params(parameterClient_, newParameters_, parameterSize());
// clear gradient after update parameter.
for (auto &para : parameters_) {
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
}
void NewRemoteParameterUpdater::startPass() {}
bool NewRemoteParameterUpdater::finishPass() { return true; }
}  // namespace paddle
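The init flow above is a small coordination protocol: exactly one trainer wins paddle_begin_init_params, pushes its locally initialized values, and calls paddle_finish_init_params; every other trainer blocks in paddle_get_params until that happens and then pulls the result. A hedged Go sketch of the trainer-side protocol follows; the interface below is an assumption modeled on the C API rather than the actual Go client surface:

package sketch

// pserverClient mirrors the C API in this diff; the Go method
// names are assumptions, not the real client's surface.
type pserverClient interface {
	BeginInitParams() bool // true if this trainer is selected to initialize
	InitParam(name string, value []byte) error
	FinishInitParams() error
	GetParams(names []string) (map[string][]byte, error)
}

// initParams runs the one-writer/many-readers startup protocol
// used by NewRemoteParameterUpdater::init above.
func initParams(c pserverClient, local map[string][]byte) (map[string][]byte, error) {
	if c.BeginInitParams() {
		// Selected trainer: push locally initialized values.
		for name, val := range local {
			if err := c.InitParam(name, val); err != nil {
				return nil, err
			}
		}
		if err := c.FinishInitParams(); err != nil {
			return nil, err
		}
		return local, nil
	}
	// Other trainers: block until initialization finishes, then
	// fetch the authoritative copy from the pservers.
	names := make([]string, 0, len(local))
	for name := range local {
		names = append(names, name)
	}
	return c.GetParams(names)
}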
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <thread>
#include "ParameterUpdater.h"
#include "libpaddle_pserver_cclient.h"
#include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/Util.h"
namespace paddle {
/**
* New remote parameter updater for dense parameters, using the Go pserver cclient.
*/
class NewRemoteParameterUpdater : public ParameterUpdater {
public:
NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec);
~NewRemoteParameterUpdater() {
releaseNewParameter(newParameters_);
releaseNewParameter(newGradients_);
if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_);
}
/**
* initialize the internal parameter client and itself.
*/
virtual void init(const std::vector<ParameterPtr>& parameters);
/**
* @brief start batch
*
* @note batch training is stateful; this helps with performance
* tuning and SGD optimization when necessary.
*/
virtual PassType startBatch(int64_t batchSize) { return PASS_TRAIN; }
/**
* send parameters to pservers and get returned parameters
* from all pservers if necessary.
*/
virtual void finishBatch(real cost);
virtual void startPass();
virtual bool finishPass();
protected:
/**
* work need to do after finishBatch
*/
virtual void updateImpl(Parameter* para);
private:
int parameterSize() { return (int)parameters_.size(); }
/**
* init parameters for the Go paddle pserver cclient.
* @param new_params
* @param type
*/
paddle_parameter** initNewParameter(ParameterType type) {
paddle_parameter** new_params =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
for (int i = 0; i < parameterSize(); ++i) {
new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
memset(new_params[i], 0, sizeof(paddle_parameter));
}
for (int i = 0; i < parameterSize(); ++i) {
ParameterPtr param = parameters_[i];
new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
new_params[i]->name = (char*)param->getName().c_str();
new_params[i]->content =
(unsigned char*)(param->getBuf(type).get()->getData());
new_params[i]->content_len =
(int)param->getBuf(type).get()->getSize() * sizeof(real);
}
return new_params;
}
void releaseNewParameter(paddle_parameter** newParams) {
if (newParams != nullptr) {
for (int i = 0; i < parameterSize(); ++i) {
free(newParams[i]);
}
free(newParams);
}
}
protected:
/// internal parameter client object for exchanging data with pserver
paddle_pserver_client parameterClient_;
/// the parameters for new pserver client
paddle_parameter** newParameters_;
/// the gradients for the new pserver client
paddle_parameter** newGradients_;
/// the specification of the parameter servers, e.g. "host1:port,host2:port"
std::string pserverSpec_;
};
} // namespace paddle
...@@ -126,6 +126,7 @@ def init_config_environment( ...@@ -126,6 +126,7 @@ def init_config_environment(
g_config=TrainerConfig(), g_config=TrainerConfig(),
g_layer_map={}, g_layer_map={},
g_parameter_map={}, g_parameter_map={},
g_parameter_initializer_map={},
g_extended_config_funcs={}, g_extended_config_funcs={},
# store command args of paddle_trainer # store command args of paddle_trainer
...@@ -439,8 +440,7 @@ def model_type(name): ...@@ -439,8 +440,7 @@ def model_type(name):
@config_class @config_class
class Bias(Cfg): class Bias(Cfg):
def __init__( def __init__(self,
self,
parameter_name=None, parameter_name=None,
learning_rate=None, learning_rate=None,
momentum=None, momentum=None,
...@@ -454,7 +454,8 @@ class Bias(Cfg): ...@@ -454,7 +454,8 @@ class Bias(Cfg):
sparse_remote_update=None, sparse_remote_update=None,
gradient_clipping_threshold=None, gradient_clipping_threshold=None,
is_static=None, is_static=None,
is_shared=None, ): is_shared=None,
initializer=None):
self.add_keys(locals()) self.add_keys(locals())
...@@ -465,6 +466,7 @@ class Input(Cfg): ...@@ -465,6 +466,7 @@ class Input(Cfg):
self, self,
input_layer_name, input_layer_name,
parameter_name=None, parameter_name=None,
initializer=None,
learning_rate=None, learning_rate=None,
momentum=None, momentum=None,
decay_rate=None, decay_rate=None,
...@@ -521,6 +523,7 @@ class Projection(Input): ...@@ -521,6 +523,7 @@ class Projection(Input):
initial_std=None, initial_std=None,
initial_strategy=None, initial_strategy=None,
initial_smart=None, initial_smart=None,
initializer=None,
num_batches_regularization=None, num_batches_regularization=None,
sparse_remote_update=None, sparse_remote_update=None,
sparse_update=None, sparse_update=None,
...@@ -1494,7 +1497,8 @@ class LayerBase(object): ...@@ -1494,7 +1497,8 @@ class LayerBase(object):
gradient_clipping_threshold=bias. gradient_clipping_threshold=bias.
gradient_clipping_threshold, gradient_clipping_threshold,
is_static=bias.is_static, is_static=bias.is_static,
is_shared=bias.is_shared, ) is_shared=bias.is_shared,
initializer=bias.initializer)
if for_self: if for_self:
self.config.bias_parameter_name = bias.parameter_name self.config.bias_parameter_name = bias.parameter_name
else: else:
...@@ -1551,7 +1555,8 @@ class LayerBase(object): ...@@ -1551,7 +1555,8 @@ class LayerBase(object):
format=format, format=format,
is_static=input_config.is_static, is_static=input_config.is_static,
is_shared=input_config.is_shared, is_shared=input_config.is_shared,
update_hooks=input_config.update_hooks) update_hooks=input_config.update_hooks,
initializer=input_config.initializer)
def set_layer_size(self, size): def set_layer_size(self, size):
if self.config.size == 0: if self.config.size == 0:
...@@ -3236,7 +3241,8 @@ def Parameter(name, ...@@ -3236,7 +3241,8 @@ def Parameter(name,
need_compact=None, need_compact=None,
is_static=None, is_static=None,
is_shared=None, is_shared=None,
update_hooks=None): update_hooks=None,
initializer=None):
config_assert(name not in g_parameter_map, config_assert(name not in g_parameter_map,
'Duplicated parameter name: ' + name) 'Duplicated parameter name: ' + name)
...@@ -3324,6 +3330,11 @@ def Parameter(name, ...@@ -3324,6 +3330,11 @@ def Parameter(name,
para.update_hooks.extend(update_hooks) para.update_hooks.extend(update_hooks)
g_parameter_map[name] = para g_parameter_map[name] = para
if initializer is not None:
config_assert(
callable(initializer),
"parameter initializer should be a callable object")
g_parameter_initializer_map[name] = initializer
@config_func @config_func
......
...@@ -95,6 +95,10 @@ class ParameterAttribute(object): ...@@ -95,6 +95,10 @@ class ParameterAttribute(object):
:param sparse_update: Enable sparse update for this parameter. It will :param sparse_update: Enable sparse update for this parameter. It will
enable both local and remote sparse update. enable both local and remote sparse update.
:type sparse_update: bool :type sparse_update: bool
:param initializer: If not None, it should be a callable object which accepts
a parameter name and returns a numpy array for the initial
value of the parameter
:type initializer: callable object
""" """
def __init__(self, def __init__(self,
...@@ -109,7 +113,8 @@ class ParameterAttribute(object): ...@@ -109,7 +113,8 @@ class ParameterAttribute(object):
learning_rate=None, learning_rate=None,
momentum=None, momentum=None,
gradient_clipping_threshold=None, gradient_clipping_threshold=None,
sparse_update=False): sparse_update=False,
initializer=None):
self.attr = {} self.attr = {}
if is_static: if is_static:
...@@ -161,6 +166,8 @@ class ParameterAttribute(object): ...@@ -161,6 +166,8 @@ class ParameterAttribute(object):
is_compatible_with(gradient_clipping_threshold, float): is_compatible_with(gradient_clipping_threshold, float):
self.attr['gradient_clipping_threshold'] = \ self.attr['gradient_clipping_threshold'] = \
gradient_clipping_threshold gradient_clipping_threshold
if initializer is not None:
self.attr['initializer'] = initializer
def set_default_parameter_name(self, name): def set_default_parameter_name(self, name):
""" """
......
...@@ -45,7 +45,12 @@ class Optimizer(object): ...@@ -45,7 +45,12 @@ class Optimizer(object):
return swig_api.ParameterUpdater.createRemoteUpdater( return swig_api.ParameterUpdater.createRemoteUpdater(
self.__opt_conf__, pass_num, use_sparse_updater) self.__opt_conf__, pass_num, use_sparse_updater)
def create_updater(self, is_local, num_passes, use_sparse_updater): def __create_new_remote_updater__(self, pserver_spec):
return swig_api.ParameterUpdater.createNewRemoteUpdater(
self.__opt_conf__, pserver_spec)
def create_updater(self, is_local, num_passes, use_sparse_updater,
pserver_spec):
""" """
create proper parameter_updater by configuration. create proper parameter_updater by configuration.
:param is_local: create local or remote parameter updater :param is_local: create local or remote parameter updater
...@@ -64,8 +69,12 @@ class Optimizer(object): ...@@ -64,8 +69,12 @@ class Optimizer(object):
if is_local: if is_local:
parameter_updater = self.__create_local_updater__() parameter_updater = self.__create_local_updater__()
else: else:
if pserver_spec is None:
parameter_updater = self.__create_remote_updater__( parameter_updater = self.__create_remote_updater__(
num_passes, use_sparse_updater) num_passes, use_sparse_updater)
else:
parameter_updater = self.__create_new_remote_updater__(
pserver_spec)
return parameter_updater return parameter_updater
......
import numpy as np import numpy as np
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
from paddle.proto.ParameterConfig_pb2 import ParameterConfig from paddle.proto.ParameterConfig_pb2 import ParameterConfig
import paddle.trainer.config_parser as cp
import struct import struct
import tarfile import tarfile
import cStringIO import cStringIO
...@@ -18,8 +19,11 @@ def create(layers): ...@@ -18,8 +19,11 @@ def create(layers):
""" """
topology = Topology(layers) topology = Topology(layers)
pool = Parameters() pool = Parameters()
initializers = cp.g_parameter_initializer_map
for param in topology.proto().parameters: for param in topology.proto().parameters:
pool.__append_config__(param) pool.__append_config__(param)
if param.name in initializers:
pool[param.name] = initializers[param.name](param.name)
return pool return pool
......
...@@ -11,6 +11,9 @@ except ImportError: ...@@ -11,6 +11,9 @@ except ImportError:
sys.exit(0) sys.exit(0)
import paddle.v2.parameters as parameters import paddle.v2.parameters as parameters
import paddle.v2.data_type as data_type
import paddle.v2.layer as layer
from paddle.v2.attr import ParamAttr
from paddle.proto.ParameterConfig_pb2 import ParameterConfig from paddle.proto.ParameterConfig_pb2 import ParameterConfig
import random import random
import cStringIO import cStringIO
...@@ -55,6 +58,25 @@ class TestParameters(unittest.TestCase): ...@@ -55,6 +58,25 @@ class TestParameters(unittest.TestCase):
p1 = params_dup.get(name) p1 = params_dup.get(name)
self.assertTrue(numpy.isclose(p0, p1).all()) self.assertTrue(numpy.isclose(p0, p1).all())
def test_initializer(self):
def initializer(name):
assert name == "fc.w"
mat = numpy.ones((3, 2), dtype=numpy.float32)
mat[1, 1] = 2
return mat
x = layer.data(name="x", type=data_type.dense_vector(3))
y = layer.fc(x,
size=2,
bias_attr=False,
param_attr=ParamAttr(
name="fc.w", initializer=initializer))
params = parameters.create(y)
val = params["fc.w"]
assert val.shape == (3, 2)
expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32)
assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -49,7 +49,8 @@ class SGD(object): ...@@ -49,7 +49,8 @@ class SGD(object):
parameters, parameters,
update_equation, update_equation,
extra_layers=None, extra_layers=None,
is_local=True): is_local=True,
pserver_spec=None):
if not isinstance(parameters, v2_parameters.Parameters): if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters') raise TypeError('parameters should be parameters')
...@@ -63,6 +64,7 @@ class SGD(object): ...@@ -63,6 +64,7 @@ class SGD(object):
self.__parameters__ = parameters self.__parameters__ = parameters
self.__topology_in_proto__ = topology.proto() self.__topology_in_proto__ = topology.proto()
self.__is_local__ = is_local self.__is_local__ = is_local
self.__pserver_spec__ = pserver_spec
self.__use_sparse_updater__ = self.__topology__.use_sparse_updater() self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
# # In local mode, disable sparse_remote_update. # # In local mode, disable sparse_remote_update.
...@@ -126,7 +128,8 @@ class SGD(object): ...@@ -126,7 +128,8 @@ class SGD(object):
__check_train_args__(**locals()) __check_train_args__(**locals())
self.__parameter_updater__ = self.__optimizer__.create_updater( self.__parameter_updater__ = self.__optimizer__.create_updater(
self.__is_local__, num_passes, self.__use_sparse_updater__) self.__is_local__, num_passes, self.__use_sparse_updater__,
self.__pserver_spec__)
self.__parameter_updater__.init(self.__gradient_machine__) self.__parameter_updater__.init(self.__gradient_machine__)
self.__gradient_machine__.start() self.__gradient_machine__.start()
......