From 55217c962d127e72ee88e042d2dd95cfe7375a65 Mon Sep 17 00:00:00 2001
From: Helin Wang
Date: Tue, 16 May 2017 18:41:01 -0400
Subject: [PATCH] Implement Pserver RPC, gradient update logic, cgo part

---
 paddle/go/cmd/pserver/.gitignore  |   1 +
 paddle/go/cmd/pserver/pserver.go  |  33 ++++++
 paddle/go/pserver/optimizer.c     |  22 ++++
 paddle/go/pserver/optimizer.go    |  51 +++++++++
 paddle/go/pserver/optimizer.h     |  19 ++++
 paddle/go/pserver/service.go      | 165 ++++++++++++++++++++++++++++++
 paddle/go/pserver/service_test.go | 154 ++++++++++++++++++++++++++++
 7 files changed, 445 insertions(+)
 create mode 100644 paddle/go/cmd/pserver/.gitignore
 create mode 100644 paddle/go/cmd/pserver/pserver.go
 create mode 100644 paddle/go/pserver/optimizer.c
 create mode 100644 paddle/go/pserver/optimizer.go
 create mode 100644 paddle/go/pserver/optimizer.h
 create mode 100644 paddle/go/pserver/service.go
 create mode 100644 paddle/go/pserver/service_test.go

diff --git a/paddle/go/cmd/pserver/.gitignore b/paddle/go/cmd/pserver/.gitignore
new file mode 100644
index 0000000000..fffd9adc4f
--- /dev/null
+++ b/paddle/go/cmd/pserver/.gitignore
@@ -0,0 +1 @@
+pserver
diff --git a/paddle/go/cmd/pserver/pserver.go b/paddle/go/cmd/pserver/pserver.go
new file mode 100644
index 0000000000..41417875fb
--- /dev/null
+++ b/paddle/go/cmd/pserver/pserver.go
@@ -0,0 +1,33 @@
+package main
+
+import (
+    "flag"
+    "net"
+    "net/http"
+    "net/rpc"
+    "strconv"
+
+    "github.com/PaddlePaddle/Paddle/paddle/go/pserver"
+)
+
+func main() {
+    port := flag.Int("p", 0, "port of the pserver")
+    flag.Parse()
+
+    s := pserver.NewService()
+    err := rpc.Register(s)
+    if err != nil {
+        panic(err)
+    }
+
+    rpc.HandleHTTP()
+    l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
+    if err != nil {
+        panic(err)
+    }
+
+    err = http.Serve(l, nil)
+    if err != nil {
+        panic(err)
+    }
+}
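(Not part of the patch.) For illustration, a minimal client for the server above might look like the sketch below. The RPC method names follow from rpc.Register being called with a *pserver.Service, so they take the form "Service.MethodName", and rpc.DialHTTP matches the server's rpc.HandleHTTP default path. The address "localhost:8080", the parameter name "w", and its size are assumptions made up for the example.

package main

import (
    "log"
    "net/rpc"

    "github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

func main() {
    // Assumes a pserver started with `pserver -p 8080` on this host.
    client, err := rpc.DialHTTP("tcp", "localhost:8080")
    if err != nil {
        log.Fatal(err)
    }

    var dummy int
    // Initialization protocol: BeginInitParams, then InitParam once
    // per parameter, then FinishInitParams.
    if err = client.Call("Service.BeginInitParams", []byte(nil), &dummy); err != nil {
        log.Fatal(err)
    }

    // "w" is a hypothetical parameter: two float32 values (8 bytes).
    p := pserver.Parameter{Name: "w", ElementType: pserver.Float32, Content: make([]byte, 8)}
    if err = client.Call("Service.InitParam", pserver.ParameterWithConfig{Param: p}, &dummy); err != nil {
        log.Fatal(err)
    }

    if err = client.Call("Service.FinishInitParams", 0, &dummy); err != nil {
        log.Fatal(err)
    }
}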
gradient: %v", p.ElementType, g.ElementType) + } + + r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) + if r != 0 { + return fmt.Errorf("optimier returned error code: %d", r) + } + return nil +} + +func (o *optimizer) Cleanup() { + if unsafe.Pointer(o.opt) != nullPtr { + C.paddle_release_optimizer(o.opt) + o.opt = (*C.paddle_optimizer)(nullPtr) + } +} diff --git a/paddle/go/pserver/optimizer.h b/paddle/go/pserver/optimizer.h new file mode 100644 index 0000000000..e1750ca608 --- /dev/null +++ b/paddle/go/pserver/optimizer.h @@ -0,0 +1,19 @@ +#ifndef PADDLE_PSERVER_OPTIMIZER_H +#define PADDLE_PSERVER_OPTIMIZER_H + +typedef enum { + PADDLE_ELEMENT_TYPE_INT32 = 0, + PADDLE_ELEMENT_TYPE_UINT32 = 1, + PADDLE_ELEMENT_TYPE_INT64 = 2, + PADDLE_ELEMENT_TYPE_UINT64 = 3, + PADDLE_ELEMENT_TYPE_FLOAT32 = 4, + PADDLE_ELEMENT_TYPE_FLOAT64 = 5, +} paddle_element_type; + +typedef struct paddle_optimizer paddle_optimizer; + +paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate); +void paddle_release_optimizer(paddle_optimizer* o); +int paddle_update_parameter(paddle_optimizer* o, void *buffer, paddle_element_type datatype, const void* gradient, int num_bytes); + +#endif /* PADDLE_PSERVER_OPTIMIZER_H */ diff --git a/paddle/go/pserver/service.go b/paddle/go/pserver/service.go new file mode 100644 index 0000000000..0d10da9880 --- /dev/null +++ b/paddle/go/pserver/service.go @@ -0,0 +1,165 @@ +package pserver + +import ( + "errors" + "fmt" + "sync" +) + +// ElementType is the type of elements of a Parameter. +type ElementType int + +var ErrUnintialized = errors.New("pserver not initialized") +var ErrAlreadyIntialized = errors.New("pserver already initialized") + +// Supported element types +const ( + Int32 ElementType = iota + UInt32 + Int64 + UInt64 + Float32 + Float64 +) + +// Parameter is a piece of data to sync with the parameter server. +type Parameter struct { + Name string + ElementType ElementType + Content []byte +} + +// ParameterWithConfig contains the parameter and the configuration. +type ParameterWithConfig struct { + Param Parameter + Config []byte // parameter configuration in Proto Buffer format +} + +// Gradient is the gradient of the parameter. +type Gradient Parameter + +type Service struct { + initialized chan struct{} + + mu sync.Mutex + opt *optimizer + paramMap map[string]Parameter +} + +func NewService() *Service { + s := &Service{} + s.paramMap = make(map[string]Parameter) + s.initialized = make(chan struct{}) + return s +} + +func (s *Service) BeginInitParams(config []byte, dummy *int) error { + select { + case <-s.initialized: + return ErrAlreadyIntialized + default: + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.opt != nil { + s.opt.Cleanup() + } + + // TODO(helin): parse learning rate from config + s.opt = newOptimizer(sgd, 0.01) + return nil +} + +func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { + select { + case <-s.initialized: + return ErrAlreadyIntialized + default: + } + + // TODO(helin): parse parameter config + + s.mu.Lock() + defer s.mu.Unlock() + + // TODO(helin): check if paramWithConfigs.Param.Content is + // properly memory aligned, if not, make copy to a memory + // aligned region. 
diff --git a/paddle/go/pserver/service.go b/paddle/go/pserver/service.go
new file mode 100644
index 0000000000..0d10da9880
--- /dev/null
+++ b/paddle/go/pserver/service.go
@@ -0,0 +1,165 @@
+package pserver
+
+import (
+    "errors"
+    "fmt"
+    "sync"
+)
+
+// ElementType is the type of elements of a Parameter.
+type ElementType int
+
+var ErrUninitialized = errors.New("pserver not initialized")
+var ErrAlreadyInitialized = errors.New("pserver already initialized")
+
+// Supported element types
+const (
+    Int32 ElementType = iota
+    UInt32
+    Int64
+    UInt64
+    Float32
+    Float64
+)
+
+// Parameter is a piece of data to sync with the parameter server.
+type Parameter struct {
+    Name        string
+    ElementType ElementType
+    Content     []byte
+}
+
+// ParameterWithConfig contains the parameter and the configuration.
+type ParameterWithConfig struct {
+    Param  Parameter
+    Config []byte // parameter configuration in Proto Buffer format
+}
+
+// Gradient is the gradient of the parameter.
+type Gradient Parameter
+
+type Service struct {
+    initialized chan struct{}
+
+    mu       sync.Mutex
+    opt      *optimizer
+    paramMap map[string]Parameter
+}
+
+func NewService() *Service {
+    s := &Service{}
+    s.paramMap = make(map[string]Parameter)
+    s.initialized = make(chan struct{})
+    return s
+}
+
+func (s *Service) BeginInitParams(config []byte, dummy *int) error {
+    select {
+    case <-s.initialized:
+        return ErrAlreadyInitialized
+    default:
+    }
+
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    if s.opt != nil {
+        s.opt.Cleanup()
+    }
+
+    // TODO(helin): parse learning rate from config
+    s.opt = newOptimizer(sgd, 0.01)
+    return nil
+}
+
+func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
+    select {
+    case <-s.initialized:
+        return ErrAlreadyInitialized
+    default:
+    }
+
+    // TODO(helin): parse parameter config
+
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    // TODO(helin): check if paramWithConfigs.Param.Content is
+    // properly memory aligned; if not, make a copy to a memory
+    // aligned region.
+    s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
+    return nil
+}
+
+func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
+    select {
+    case <-s.initialized:
+        return ErrAlreadyInitialized
+    default:
+    }
+
+    close(s.initialized)
+    return nil
+}
+
+func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
+    select {
+    case <-s.initialized:
+    default:
+        return ErrUninitialized
+    }
+
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    for _, g := range grads {
+        if _, ok := s.paramMap[g.Name]; !ok {
+            return fmt.Errorf("parameter: %s does not exist", g.Name)
+        }
+    }
+
+    var wg sync.WaitGroup
+    for _, g := range grads {
+        wg.Add(1)
+        go func(p Parameter, g Gradient) {
+            s.opt.UpdateParameter(p, g)
+            wg.Done()
+        }(s.paramMap[g.Name], g)
+    }
+
+    wg.Wait()
+    return nil
+}
+
+func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
+    <-s.initialized
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    for _, n := range names {
+        if _, ok := s.paramMap[n]; !ok {
+            return fmt.Errorf("parameter: %s does not exist", n)
+        }
+    }
+
+    *parameters = make([]Parameter, len(names))
+    for i, n := range names {
+        // The parameter content (a byte slice) may change
+        // during RPC serialization due to concurrent writes
+        // from other goroutines. We allow this because
+        // mini-batch based deep learning optimization methods
+        // are stochastic in nature. This race condition is
+        // allowed deliberately to save the program from making
+        // a copy of the parameter content.
+        (*parameters)[i] = s.paramMap[n]
+    }
+
+    return nil
+}
+
+func (s *Service) SaveModel(path string, dummy *int) error {
+    <-s.initialized
+
+    // TODO
+    return nil
+}
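(Not part of the patch.) The service's concurrency hinges on one idiom: s.initialized is a channel that is only ever closed, never sent on. Receiving from a closed channel never blocks, so a select with a default case answers "initialized yet?" without blocking, while a bare receive blocks until FinishInitParams closes the channel. A self-contained sketch of the idiom, with nothing assumed beyond the standard library:

package main

import (
    "fmt"
    "time"
)

func main() {
    initialized := make(chan struct{})

    // Before close: select with default reports "not ready" without
    // blocking, like SendGrads returning ErrUninitialized.
    select {
    case <-initialized:
        fmt.Println("ready")
    default:
        fmt.Println("not ready")
    }

    // A blocked reader, like GetParams or SaveModel, waits on the barrier.
    done := make(chan struct{})
    go func() {
        <-initialized // blocks until the channel is closed
        fmt.Println("unblocked")
        close(done)
    }()

    time.Sleep(10 * time.Millisecond)
    close(initialized) // FinishInitParams does exactly this
    <-done
}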
diff --git a/paddle/go/pserver/service_test.go b/paddle/go/pserver/service_test.go
new file mode 100644
index 0000000000..ebeff1fb89
--- /dev/null
+++ b/paddle/go/pserver/service_test.go
@@ -0,0 +1,154 @@
+package pserver_test
+
+import (
+    "reflect"
+    "sync"
+    "testing"
+
+    "github.com/PaddlePaddle/Paddle/paddle/go/pserver"
+)
+
+func TestFull(t *testing.T) {
+    s := pserver.NewService()
+    var dummy int
+    err := s.BeginInitParams(nil, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    var p pserver.Parameter
+    p.Name = "param_a"
+    p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
+    p.ElementType = pserver.Int32
+    err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    var p1 pserver.Parameter
+    p1.Name = "param_b"
+    p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+    p1.ElementType = pserver.Float32
+    err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    err = s.FinishInitParams(0, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    var params []pserver.Parameter
+    err = s.GetParams([]string{"param_b", "param_a"}, &params)
+    if err != nil {
+        t.FailNow()
+    }
+
+    if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[1], p) {
+        t.FailNow()
+    }
+
+    grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
+    err = s.SendGrads(grads, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    var params1 []pserver.Parameter
+    err = s.GetParams([]string{"param_b", "param_a"}, &params1)
+    if err != nil {
+        t.FailNow()
+    }
+
+    if len(params1) != 2 {
+        t.FailNow()
+    }
+
+    // we don't care about the content, since it has already been updated by the optimizer
+    params1[0].Content = nil
+    params1[1].Content = nil
+    p.Content = nil
+    p1.Content = nil
+
+    if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[1], p) {
+        t.FailNow()
+    }
+}
+
+func TestMultipleInit(t *testing.T) {
+    s := pserver.NewService()
+    var dummy int
+    err := s.BeginInitParams(nil, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    // this is fine, it's possible for a client to call init
+    // multiple times.
+    err = s.BeginInitParams(nil, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    err = s.FinishInitParams(0, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    err = s.FinishInitParams(0, &dummy)
+    if err != pserver.ErrAlreadyInitialized {
+        t.FailNow()
+    }
+
+    err = s.BeginInitParams(nil, &dummy)
+    if err != pserver.ErrAlreadyInitialized {
+        t.FailNow()
+    }
+}
+
+func TestUninitialized(t *testing.T) {
+    s := pserver.NewService()
+    var dummy int
+    err := s.SendGrads(nil, &dummy)
+    if err != pserver.ErrUninitialized {
+        t.FailNow()
+    }
+}
+
+func TestBlockUntilInitialized(t *testing.T) {
+    s := pserver.NewService()
+    var wg sync.WaitGroup
+    wg.Add(1)
+    go func() {
+        var params []pserver.Parameter
+        err := s.GetParams(nil, &params)
+        if err != nil {
+            t.Error(err) // FailNow must not be called outside the test goroutine
+        }
+        wg.Done()
+    }()
+
+    wg.Add(1)
+    go func() {
+        var dummy int
+        err := s.SaveModel("", &dummy)
+        if err != nil {
+            t.Error(err)
+        }
+        wg.Done()
+    }()
+
+    var dummy int
+    err := s.BeginInitParams(nil, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    err = s.FinishInitParams(0, &dummy)
+    if err != nil {
+        t.FailNow()
+    }
+
+    wg.Wait()
+}
--
GitLab