diff --git a/go/pserver/cclient/CMakeLists.txt b/go/pserver/cclient/CMakeLists.txt index 65a38ba1add80a5425c6f135ff82b65dec811f03..b3e79ca661d0832821628b7cc6b540e17db45118 100644 --- a/go/pserver/cclient/CMakeLists.txt +++ b/go/pserver/cclient/CMakeLists.txt @@ -9,6 +9,8 @@ project(cxx_go C Go) include(golang) include(flags) +cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags) + go_library(paddle_pserver_cclient STATIC) if(WITH_TESTING) add_subdirectory(test) diff --git a/go/pserver/cclient/test/optimizer.pb.txt b/go/pserver/cclient/test/testdata/optimizer.pb.txt similarity index 100% rename from go/pserver/cclient/test/optimizer.pb.txt rename to go/pserver/cclient/test/testdata/optimizer.pb.txt diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 40748d03c1da2115ffb11b2f465a7a5d688d6271..12bf055b4deb2933701ed4c8b4d058e1072ace99 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -4,8 +4,7 @@ package pserver // TODO(zhihong): move compile flags to cmake go_library #cgo pkg-config: protobuf #cgo CFLAGS: -I ../../ -#cgo LDFLAGS: ../../build/paddle/optimizer/libpaddle_optimizer.a ../../build/proto/libpaddle_proto.a ../../third_party/install/glog/lib/libglog.a ../../third_party/install/gtest/lib/libgtest.a ../../third_party/install/gflags/lib/libgflags.a ../../third_party/install/openblas/lib/libopenblas.a -I/usr/local/lib/ -lprotobuf -#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/lib/libdep.a +#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/go/pserver/cclient/libpaddle_go_optimizer.a #include "paddle/optimizer/optimizer.h" */ import "C" @@ -18,26 +17,50 @@ var nullPtr = unsafe.Pointer(uintptr(0)) type optimizer struct { opt *C.struct_paddle_optimizer + // used in GetParam, reconstruct Parameter from optimizer + ElementType ElementType +} + +func cArrayToSlice(p unsafe.Pointer, len int) []byte { + if p == nullPtr { + return nil + } + + // create a Go clice backed by a C array, reference: + // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // + // Go garbage collector will not interact with this data, need + // to be freed properly. + return (*[1 << 30]byte)(p)[:len:len] } func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { o := &optimizer{} p := paramWithConfigs.Param c := paramWithConfigs.Config - o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(p.Content), c.int(p.Length), nullPtr, 0) + buffer := &p.Content[0] + o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(buffer), C.int(len(p.Content)), nullPtr, 0) return o } -func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { - if p.Length != g.Length { - return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, p.Length, g.Length) +func (o *optimizer) GetWeights(p *Parameter) error { + + var buffer unsafe.Pointer + buffer_len := C.paddle_optimizer_get_weights(unsafe.Pointer(o), &buffer) + if buffer_len == 0 || buffer == nullPtr { + return fmt.Errorf("parameter optimizer error : %s get failed", p.name) } + p.Content = cArrayToSlice(buffer, int(buffer_len)) + return nil +} - if p.ElementType != g.ElementType { - return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType) +func (o *optimizer) UpdateParameter(g Gradient) error { + if o.ElementType != g.ElementType { + return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, g.ElementType, g.ElementType) } - r := C.paddle_update_parameter(o.opt, C.paddle_element_type(p.ElementType), unsafe.Pointer(g.Content), C.int(g.Length)) + // FIXME: do we need a copy? discard g.Content by GC ok + r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(g.Content), C.int(len(g.Content))) if r != 0 { return fmt.Errorf("optimizer update returned error code: %d", r) } diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index 4930f0d95f9852c17ce867c71980a6df0888daa3..eac744b5cdb1be4f5fe593add8d836a78c9c6224 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -1,14 +1,22 @@ package pserver -import "testing" +import ( + "io/ioutil" + "testing" +) -func TestSGDCreateRelease(t *testing.T) { - param := pserver.ParameterWithConfig{ - Param : pserver.Parameter{Name : "a", - ElementType: , - Content: , - Length : } +func TestOptimizerCreateRelease(t *testing.T) { + p := Parameter{ + Name: "a", + ElementType: Float32, } - o := newOptimizer(sgd, 1) + p.Content = []byte{0.1, 0.3} + config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt") + + param := ParameterWithConfig{ + Param: p, + Config: config, + } + o := newOptimizer(param) o.Cleanup() } diff --git a/go/pserver/service.go b/go/pserver/service.go index 32449f66b7e841314756a4a306603f23f2c94d15..d0d57136b5ef1d5378cb8efef19dcf6a5755d88a 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -107,7 +107,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error { return fmt.Errorf("parameter: %s does not exist", g.Name) } - return o.UpdateParameter(p, g) + return o.UpdateParameter(g) } // GetParam gets parameters from the parameter server. @@ -116,7 +116,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { s.mu.Lock() defer s.mu.Unlock() - p, ok := s.paramMap[name] + opt, ok := s.optMap[name] if !ok { return fmt.Errorf("parameter: %s does not exist", name) } @@ -128,8 +128,11 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { // nature. This race condition is allowed deliberately // to save the program from making a copy of the // paramter content. - *parameter = p - return nil + p.Name = name + p.ElementType = opt.ElementType + + ok := opt.GetWeights(¶meter) + return ok } // Save tells the parameter server to save parameters. diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 1b2626f7db8ebaf86b7de3f88bee7aac8a72a30a..b746d13e1ca71e697c464f84d915af029d37120c 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -13,9 +13,7 @@ func TestFull(t *testing.T) { s := pserver.NewService() var p pserver.Parameter p.Name = "param_a" - ElementValue := []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} - p.Content = &ElementValue[0] - p.Length = len(ElementValue) + p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) if err != nil { @@ -24,9 +22,7 @@ func TestFull(t *testing.T) { var p1 pserver.Parameter p1.Name = "param_b" - ElementValue = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - p1.Content = &ElementValue[0] - p1.Length = len(ElementValue) + p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.ElementType = pserver.Float32 err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil) if err != nil { diff --git a/paddle/optimizer/CMakeLists.txt b/paddle/optimizer/CMakeLists.txt index 4536f62ec7c2c3423d91e309dee993d4212160fe..35f04789cfe6dc445973f0f922269f6f78b713a3 100644 --- a/paddle/optimizer/CMakeLists.txt +++ b/paddle/optimizer/CMakeLists.txt @@ -12,6 +12,7 @@ set(OPITMIZER_SRCS add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) add_dependencies(paddle_optimizer gen_proto_cpp) + if(WITH_TESTING) add_simple_unittest(serialization_test) add_simple_unittest(parameter_optimizer_test)