diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index 368047d6f89e080016909efbc5bd090c42530bfd..49d9df5898dc69a8e50c4a712e9b04467b45a7b8 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -2,6 +2,7 @@ package pserver import ( "io/ioutil" + "reflect" "testing" ) @@ -22,3 +23,26 @@ func TestOptimizerCreateRelease(t *testing.T) { o := newOptimizer(param) o.Cleanup() } + +func TestOptimizerFull(t *testing.T) { + p := Parameter{ + Name: "a", + ElementType: Float32, + } + p.Content = []byte{1, 3} + config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt") + if err != nil { + t.Fatalf("read optimizer proto failed") + } + param := ParameterWithConfig{ + Param: p, + Config: config, + } + o := newOptimizer(param) + g := Gradient(p) + if !reflect.DeepEqual(p.Content, o.GetWeights()) { + t.FailNow() + } + o.UpdateParameter(g) + o.Cleanup() +}