提交 12749ad5 编写于 作者: D dongzhihong

"fix cmake flags in optimizer"

上级 59b40ecb
......@@ -9,6 +9,8 @@ project(cxx_go C Go)
include(golang)
include(flags)
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags)
go_library(paddle_pserver_cclient STATIC)
if(WITH_TESTING)
add_subdirectory(test)
......
......@@ -4,8 +4,7 @@ package pserver
// TODO(zhihong): move compile flags to cmake go_library
#cgo pkg-config: protobuf
#cgo CFLAGS: -I ../../
#cgo LDFLAGS: ../../build/paddle/optimizer/libpaddle_optimizer.a ../../build/proto/libpaddle_proto.a ../../third_party/install/glog/lib/libglog.a ../../third_party/install/gtest/lib/libgtest.a ../../third_party/install/gflags/lib/libgflags.a ../../third_party/install/openblas/lib/libopenblas.a -I/usr/local/lib/ -lprotobuf
#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/lib/libdep.a
#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/go/pserver/cclient/libpaddle_go_optimizer.a
#include "paddle/optimizer/optimizer.h"
*/
import "C"
......@@ -18,26 +17,50 @@ var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct {
opt *C.struct_paddle_optimizer
// used in GetParam, reconstruct Parameter from optimizer
ElementType ElementType
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array, reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// Go garbage collector will not interact with this data, need
// to be freed properly.
return (*[1 << 30]byte)(p)[:len:len]
}
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
o := &optimizer{}
p := paramWithConfigs.Param
c := paramWithConfigs.Config
o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(p.Content), c.int(p.Length), nullPtr, 0)
buffer := &p.Content[0]
o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(buffer), C.int(len(p.Content)), nullPtr, 0)
return o
}
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
if p.Length != g.Length {
return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, p.Length, g.Length)
func (o *optimizer) GetWeights(p *Parameter) error {
var buffer unsafe.Pointer
buffer_len := C.paddle_optimizer_get_weights(unsafe.Pointer(o), &buffer)
if buffer_len == 0 || buffer == nullPtr {
return fmt.Errorf("parameter optimizer error : %s get failed", p.name)
}
p.Content = cArrayToSlice(buffer, int(buffer_len))
return nil
}
if p.ElementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType)
func (o *optimizer) UpdateParameter(g Gradient) error {
if o.ElementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, g.ElementType, g.ElementType)
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(p.ElementType), unsafe.Pointer(g.Content), C.int(g.Length))
// FIXME: do we need a copy? discard g.Content by GC ok
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(g.Content), C.int(len(g.Content)))
if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r)
}
......
package pserver
import "testing"
import (
"io/ioutil"
"testing"
)
func TestSGDCreateRelease(t *testing.T) {
param := pserver.ParameterWithConfig{
Param : pserver.Parameter{Name : "a",
ElementType: ,
Content: ,
Length : }
func TestOptimizerCreateRelease(t *testing.T) {
p := Parameter{
Name: "a",
ElementType: Float32,
}
o := newOptimizer(sgd, 1)
p.Content = []byte{0.1, 0.3}
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
param := ParameterWithConfig{
Param: p,
Config: config,
}
o := newOptimizer(param)
o.Cleanup()
}
......@@ -107,7 +107,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error {
return fmt.Errorf("parameter: %s does not exist", g.Name)
}
return o.UpdateParameter(p, g)
return o.UpdateParameter(g)
}
// GetParam gets parameters from the parameter server.
......@@ -116,7 +116,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
s.mu.Lock()
defer s.mu.Unlock()
p, ok := s.paramMap[name]
opt, ok := s.optMap[name]
if !ok {
return fmt.Errorf("parameter: %s does not exist", name)
}
......@@ -128,8 +128,11 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// paramter content.
*parameter = p
return nil
p.Name = name
p.ElementType = opt.ElementType
ok := opt.GetWeights(&parameter)
return ok
}
// Save tells the parameter server to save parameters.
......
......@@ -13,9 +13,7 @@ func TestFull(t *testing.T) {
s := pserver.NewService()
var p pserver.Parameter
p.Name = "param_a"
ElementValue := []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.Content = &ElementValue[0]
p.Length = len(ElementValue)
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil {
......@@ -24,9 +22,7 @@ func TestFull(t *testing.T) {
var p1 pserver.Parameter
p1.Name = "param_b"
ElementValue = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.Content = &ElementValue[0]
p1.Length = len(ElementValue)
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
if err != nil {
......
......@@ -12,6 +12,7 @@ set(OPITMIZER_SRCS
add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(paddle_optimizer gen_proto_cpp)
if(WITH_TESTING)
add_simple_unittest(serialization_test)
add_simple_unittest(parameter_optimizer_test)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册