提交 12749ad5 编写于 作者: D dongzhihong

"fix cmake flags in optimizer"

上级 59b40ecb
...@@ -9,6 +9,8 @@ project(cxx_go C Go) ...@@ -9,6 +9,8 @@ project(cxx_go C Go)
include(golang) include(golang)
include(flags) include(flags)
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags)
go_library(paddle_pserver_cclient STATIC) go_library(paddle_pserver_cclient STATIC)
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(test) add_subdirectory(test)
......
...@@ -4,8 +4,7 @@ package pserver ...@@ -4,8 +4,7 @@ package pserver
// TODO(zhihong): move compile flags to cmake go_library // TODO(zhihong): move compile flags to cmake go_library
#cgo pkg-config: protobuf #cgo pkg-config: protobuf
#cgo CFLAGS: -I ../../ #cgo CFLAGS: -I ../../
#cgo LDFLAGS: ../../build/paddle/optimizer/libpaddle_optimizer.a ../../build/proto/libpaddle_proto.a ../../third_party/install/glog/lib/libglog.a ../../third_party/install/gtest/lib/libgtest.a ../../third_party/install/gflags/lib/libgflags.a ../../third_party/install/openblas/lib/libopenblas.a -I/usr/local/lib/ -lprotobuf #cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/go/pserver/cclient/libpaddle_go_optimizer.a
#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/lib/libdep.a
#include "paddle/optimizer/optimizer.h" #include "paddle/optimizer/optimizer.h"
*/ */
import "C" import "C"
...@@ -18,26 +17,50 @@ var nullPtr = unsafe.Pointer(uintptr(0)) ...@@ -18,26 +17,50 @@ var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct { type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
// used in GetParam, reconstruct Parameter from optimizer
ElementType ElementType
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
} }
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
o := &optimizer{} o := &optimizer{}
p := paramWithConfigs.Param p := paramWithConfigs.Param
c := paramWithConfigs.Config c := paramWithConfigs.Config
o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(p.Content), c.int(p.Length), nullPtr, 0) buffer := &p.Content[0]
o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(buffer), C.int(len(p.Content)), nullPtr, 0)
return o return o
} }
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { func (o *optimizer) GetWeights(p *Parameter) error {
if p.Length != g.Length {
return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, p.Length, g.Length) var buffer unsafe.Pointer
buffer_len := C.paddle_optimizer_get_weights(unsafe.Pointer(o), &buffer)
if buffer_len == 0 || buffer == nullPtr {
return fmt.Errorf("parameter optimizer error : %s get failed", p.name)
} }
p.Content = cArrayToSlice(buffer, int(buffer_len))
return nil
}
if p.ElementType != g.ElementType { func (o *optimizer) UpdateParameter(g Gradient) error {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType) if o.ElementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, g.ElementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(p.ElementType), unsafe.Pointer(g.Content), C.int(g.Length)) // FIXME: do we need a copy? discard g.Content by GC ok
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(g.Content), C.int(len(g.Content)))
if r != 0 { if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r) return fmt.Errorf("optimizer update returned error code: %d", r)
} }
......
package pserver package pserver
import "testing" import (
"io/ioutil"
"testing"
)
func TestSGDCreateRelease(t *testing.T) { func TestOptimizerCreateRelease(t *testing.T) {
param := pserver.ParameterWithConfig{ p := Parameter{
Param : pserver.Parameter{Name : "a", Name: "a",
ElementType: , ElementType: Float32,
Content: ,
Length : }
} }
o := newOptimizer(sgd, 1) p.Content = []byte{0.1, 0.3}
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
param := ParameterWithConfig{
Param: p,
Config: config,
}
o := newOptimizer(param)
o.Cleanup() o.Cleanup()
} }
...@@ -107,7 +107,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error { ...@@ -107,7 +107,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error {
return fmt.Errorf("parameter: %s does not exist", g.Name) return fmt.Errorf("parameter: %s does not exist", g.Name)
} }
return o.UpdateParameter(p, g) return o.UpdateParameter(g)
} }
// GetParam gets parameters from the parameter server. // GetParam gets parameters from the parameter server.
...@@ -116,7 +116,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -116,7 +116,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
p, ok := s.paramMap[name] opt, ok := s.optMap[name]
if !ok { if !ok {
return fmt.Errorf("parameter: %s does not exist", name) return fmt.Errorf("parameter: %s does not exist", name)
} }
...@@ -128,8 +128,11 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -128,8 +128,11 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// nature. This race condition is allowed deliberately // nature. This race condition is allowed deliberately
// to save the program from making a copy of the // to save the program from making a copy of the
// paramter content. // paramter content.
*parameter = p p.Name = name
return nil p.ElementType = opt.ElementType
ok := opt.GetWeights(&parameter)
return ok
} }
// Save tells the parameter server to save parameters. // Save tells the parameter server to save parameters.
......
...@@ -13,9 +13,7 @@ func TestFull(t *testing.T) { ...@@ -13,9 +13,7 @@ func TestFull(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var p pserver.Parameter var p pserver.Parameter
p.Name = "param_a" p.Name = "param_a"
ElementValue := []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.Content = &ElementValue[0]
p.Length = len(ElementValue)
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil { if err != nil {
...@@ -24,9 +22,7 @@ func TestFull(t *testing.T) { ...@@ -24,9 +22,7 @@ func TestFull(t *testing.T) {
var p1 pserver.Parameter var p1 pserver.Parameter
p1.Name = "param_b" p1.Name = "param_b"
ElementValue = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.Content = &ElementValue[0]
p1.Length = len(ElementValue)
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
if err != nil { if err != nil {
......
...@@ -12,6 +12,7 @@ set(OPITMIZER_SRCS ...@@ -12,6 +12,7 @@ set(OPITMIZER_SRCS
add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(paddle_optimizer gen_proto_cpp) add_dependencies(paddle_optimizer gen_proto_cpp)
if(WITH_TESTING) if(WITH_TESTING)
add_simple_unittest(serialization_test) add_simple_unittest(serialization_test)
add_simple_unittest(parameter_optimizer_test) add_simple_unittest(parameter_optimizer_test)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册