提交 11660eab 编写于 作者: H Helin Wang

Fix optimizer parameter buffer allocation size.

The buffer allocation size should be number of bytes, not number of
floats.
上级 3c5cc644
......@@ -100,13 +100,13 @@ func (l lister) List() []client.Server {
return l
}
func ClientTest(t *testing.T, c *client.Client) {
func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams()
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 100
const numParameter = 1000
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
......@@ -128,7 +128,7 @@ func ClientTest(t *testing.T, c *client.Client) {
}
var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ {
for i := 0; i < numParameter; i++ {
var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32
......@@ -169,13 +169,14 @@ func TestNativeClient(t *testing.T) {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
}
c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1)
testClient(t, c1)
}
// TODO: tmperary disable etcdClient test for dependency of etcd)
// EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
func EtcdClient(t *testing.T) {
initEtcdClient()
etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2)
testClient(t, c2)
}
......@@ -19,6 +19,7 @@ var nullPtr = unsafe.Pointer(uintptr(0))
type optimizer struct {
opt *C.struct_paddle_optimizer
elementType ElementType
contentLen int
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
......@@ -37,10 +38,11 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
o.contentLen = len(paramWithConfigs.Param.Content)
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := State
paramBufferSize := C.size_t(len(p.Content) / C.sizeof_float)
paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": paramBufferSize,
......@@ -78,7 +80,11 @@ func (o *optimizer) UpdateParameter(g Gradient) error {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float)
if o.contentLen != len(g.Content) {
return fmt.Errorf("Name: %s, parameter and gradient does not have same content len, parameter: %d, gradient: %d", g.Name, o.contentLen, len(g.Content))
}
r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
if r != 0 {
return fmt.Errorf("optimizer update returned error code: %d", r)
}
......
......@@ -31,7 +31,7 @@ func TestServiceFull(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
var p1 pserver.Parameter
......@@ -40,40 +40,40 @@ func TestServiceFull(t *testing.T) {
p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
var param pserver.Parameter
err = s.GetParam("param_b", &param)
if err != nil {
t.FailNow()
t.Fatal(err)
}
if !reflect.DeepEqual(param, p1) {
t.FailNow()
t.Fatal("not equal:", param, p1)
}
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.SendGrad(g2, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
var param1 pserver.Parameter
err = s.GetParam("param_a", &param1)
if err != nil {
t.FailNow()
t.Fatal(err)
}
// don't compare content, since it's already changed by
......@@ -82,7 +82,7 @@ func TestServiceFull(t *testing.T) {
p.Content = nil
if !reflect.DeepEqual(param1, p) {
t.FailNow()
t.Fatal("not equal:", param1, p)
}
}
......@@ -90,16 +90,16 @@ func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err.Error() != pserver.AlreadyInitialized {
t.FailNow()
t.Fatal(err)
}
}
......@@ -108,7 +108,7 @@ func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.FailNow()
t.Fatal(err)
}
}
......@@ -154,12 +154,12 @@ func TestBlockUntilInitialized(t *testing.T) {
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
t.Fatal(err)
}
wg.Wait()
......
......@@ -44,8 +44,8 @@ paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
const int state_len) {
paddle_optimizer* optimizer = new paddle_optimizer;
std::string config(config_proto, config_proto + config_proto_len);
Tensor* parameter =
new Tensor(reinterpret_cast<float*>(param_buffer), num_bytes);
Tensor* parameter = new Tensor(reinterpret_cast<float*>(param_buffer),
num_bytes / sizeof(float));
optimizer->impl = ParameterOptimizer::Create(config, parameter);
if (state != nullptr) {
std::string s(state, state + state_len);
......@@ -65,7 +65,8 @@ int paddle_update_parameter(paddle_optimizer* o,
int num_bytes) {
// TOOD(zhihong): datatype not work. need to add the runtime datatype
auto grad_type = reinterpret_cast<const float*>(grad_buffer);
Tensor* gradient = new Tensor(const_cast<float*>(grad_type), num_bytes);
Tensor* gradient =
new Tensor(const_cast<float*>(grad_type), num_bytes / sizeof(float));
o->impl->Update(gradient);
return PADDLE_SUCCESS;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册