提交 c44a94b4 编写于 作者: D dongzhihong

"fix cmake build flags"

上级 12749ad5
......@@ -123,8 +123,9 @@ func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
et := pserver.ElementType(param.element_type)
name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
pc := pserver.ParameterWithConfig{
Param: pserver.Parameter{Name: name, ElementType: et, Content: param.content, Length: para.content_len},
Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
Config: cArrayToSlice(param_config, int(config_len)),
}
c := get(client)
......@@ -166,7 +167,8 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient
grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name)
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: grad.content, Length: grad.content_len})
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
}
c := get(client)
......@@ -223,14 +225,14 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
}
if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) != p.Length {
if int(param.content_len) != len(p.Content) {
log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
return C.PSERVER_ERROR
}
}
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(p.Content), C.size_t(p.Length))
param.content_len = C.int(p.Length)
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
param.content_len = C.int(len(p.Content))
param.element_type = C.paddle_element_type(p.ElementType)
}
......
......@@ -50,10 +50,10 @@ void getParams(paddle_pserver_client c) {
int main() {
char addr[] = "localhost:3000";
paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
char config_proto[1024];
char *config_proto;
size_t config_proto_len = 0;
ssize_t nread;
FILE *fp = fopen("optimizer.pb.txt", "r");
FILE *fp = fopen("testdata/optimizer.pb.txt", "r");
if(!fp) { fail(); }
while((nread = getline(&config_proto, &config_proto_len, fp)) != -1) {
printf("%s", config_proto);
......@@ -70,7 +70,7 @@ retry:
param.name = name_a;
param.content = content_a;
param.content_len = 2000;
int error = paddle_init_param(c, param, config_proto, config_proto_len);
int error = paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) {
goto retry;
}
......@@ -79,7 +79,7 @@ retry:
param.name = name_b;
param.content = content_b;
param.content_len = 3000;
error = paddle_init_param(c, param, NULL, 0);
error = paddle_init_param(c, param, (void *)config_proto, config_proto_len);
if (error != 0) {
goto retry;
}
......
......@@ -75,9 +75,7 @@ func TestClientFull(t *testing.T) {
var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32
ElementValue := make([]byte, (i+1)*100)
p.Content = &ElementValue[0]
p.Length = len(ElementValue)
p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p})
if err != nil {
t.Fatal(err)
......@@ -94,9 +92,7 @@ func TestClientFull(t *testing.T) {
var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32
ElementValue := make([]byte, (i+1)*100)
g.Content = &ElementValue[0]
g.Length = len(ElementValue)
g.Content = make([]byte, (i+1)*100)
grads = append(grads, g)
}
......
......@@ -4,7 +4,7 @@ package pserver
// TODO(zhihong): move compile flags to cmake go_library
#cgo pkg-config: protobuf
#cgo CFLAGS: -I ../../
#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/go/pserver/cclient/libpaddle_go_optimizer.a
#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++
#include "paddle/optimizer/optimizer.h"
*/
import "C"
......@@ -38,17 +38,20 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
o := &optimizer{}
p := paramWithConfigs.Param
c := paramWithConfigs.Config
buffer := &p.Content[0]
o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(buffer), C.int(len(p.Content)), nullPtr, 0)
var cbuffer unsafe.Pointer
cbuffer = unsafe.Pointer(&p.Content[0])
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)),
(*C.char)(nullPtr), 0)
return o
}
func (o *optimizer) GetWeights(p *Parameter) error {
var buffer unsafe.Pointer
buffer_len := C.paddle_optimizer_get_weights(unsafe.Pointer(o), &buffer)
buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer)
if buffer_len == 0 || buffer == nullPtr {
return fmt.Errorf("parameter optimizer error : %s get failed", p.name)
return fmt.Errorf("parameter optimizer error : %s get failed", p.Name)
}
p.Content = cArrayToSlice(buffer, int(buffer_len))
return nil
......@@ -60,7 +63,9 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
}
// FIXME: do we need a copy? discard g.Content by GC ok
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(g.Content), C.int(len(g.Content)))
var cbuffer unsafe.Pointer
cbuffer = unsafe.Pointer(&g.Content[0])
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), cbuffer, C.int(len(g.Content)))
if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r)
}
......
......@@ -8,11 +8,13 @@ import (
func TestOptimizerCreateRelease(t *testing.T) {
p := Parameter{
Name: "a",
ElementType: Float32,
ElementType: Int32,
}
p.Content = []byte{0.1, 0.3}
p.Content = []byte{1, 3}
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
param := ParameterWithConfig{
Param: p,
Config: config,
......
......@@ -128,11 +128,11 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// paramter content.
p.Name = name
p.ElementType = opt.ElementType
parameter.Name = name
parameter.ElementType = opt.ElementType
ok := opt.GetWeights(&parameter)
return ok
err := opt.GetWeights(parameter)
return err
}
// Save tells the parameter server to save parameters.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册