提交 864386d5 编写于 作者: D dongzhihong

"change log in optimizer"

上级 cebfae94
...@@ -47,7 +47,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { ...@@ -47,7 +47,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
} }
func (o *optimizer) GetWeights(p *Parameter) error { func (o *optimizer) GetWeights(p *Parameter) error {
// FIXME: get weigths from optimizer has bug
var buffer unsafe.Pointer var buffer unsafe.Pointer
buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer) buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer)
if buffer_len == 0 || buffer == nullPtr { if buffer_len == 0 || buffer == nullPtr {
...@@ -59,7 +59,7 @@ func (o *optimizer) GetWeights(p *Parameter) error { ...@@ -59,7 +59,7 @@ func (o *optimizer) GetWeights(p *Parameter) error {
func (o *optimizer) UpdateParameter(g Gradient) error { func (o *optimizer) UpdateParameter(g Gradient) error {
if o.ElementType != g.ElementType { if o.ElementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, g.ElementType, g.ElementType) return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.ElementType, g.ElementType)
} }
// FIXME: do we need a copy? discard g.Content by GC ok // FIXME: do we need a copy? discard g.Content by GC ok
......
...@@ -2,6 +2,7 @@ package pserver_test ...@@ -2,6 +2,7 @@ package pserver_test
import ( import (
"io/ioutil" "io/ioutil"
"reflect"
"sync" "sync"
"testing" "testing"
"time" "time"
...@@ -9,7 +10,7 @@ import ( ...@@ -9,7 +10,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
func TestFull(t *testing.T) { func TestNewName(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var p pserver.Parameter var p pserver.Parameter
p.Name = "param_a" p.Name = "param_a"
...@@ -25,69 +26,69 @@ func TestFull(t *testing.T) { ...@@ -25,69 +26,69 @@ func TestFull(t *testing.T) {
t.FailNow() t.FailNow()
} }
// var p1 pserver.Parameter var p1 pserver.Parameter
// p1.Name = "param_b" p1.Name = "param_b"
// p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
// fmt.Println("paddle passed") err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
// err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil) if err != nil {
// if err != nil { t.FailNow()
// t.FailNow() }
// }
err = s.FinishInitParams(0, nil)
// err = s.FinishInitParams(0, nil) if err != nil {
// if err != nil { t.FailNow()
// t.FailNow() }
// }
var param pserver.Parameter
// var param pserver.Parameter err = s.GetParam("param_b", &param)
// err = s.GetParam("param_b", &param) if err != nil {
// if err != nil { t.FailNow()
// t.FailNow() }
// }
if !reflect.DeepEqual(param, p1) {
// if !reflect.DeepEqual(param, p1) { t.FailNow()
// t.FailNow() }
// }
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
// g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
// err = s.SendGrad(g1, nil) err = s.SendGrad(g1, nil)
// if err != nil { if err != nil {
// t.FailNow() t.FailNow()
// } }
// err = s.SendGrad(g2, nil) err = s.SendGrad(g2, nil)
// if err != nil { if err != nil {
// t.FailNow() t.FailNow()
// } }
// var param1 pserver.Parameter var param1 pserver.Parameter
// err = s.GetParam("param_a", &param1) err = s.GetParam("param_a", &param1)
// if err != nil { if err != nil {
// t.FailNow() t.FailNow()
// } }
// // don't compare content, since it's already changed by // don't compare content, since it's already changed by
// // gradient update. // gradient update.
// param1.Content = nil param1.Content = nil
// p.Content = nil p.Content = nil
// if !reflect.DeepEqual(param1, p) { if !reflect.DeepEqual(param1, p) {
// t.FailNow() t.FailNow()
// } }
// } }
// func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
// s := pserver.NewService() s := pserver.NewService()
// err := s.FinishInitParams(0, nil) err := s.FinishInitParams(0, nil)
// if err != nil { if err != nil {
// t.FailNow() t.FailNow()
// } }
// err = s.FinishInitParams(0, nil) err = s.FinishInitParams(0, nil)
// if err.Error() != pserver.AlreadyInitialized { if err.Error() != pserver.AlreadyInitialized {
// t.FailNow() t.FailNow()
// } }
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册