// service_test.go

package pserver_test

import (
	"reflect"
	"sync"
	"testing"

	"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
	err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
		t.FailNow()
	}

	grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
	err = s.SendGrads(grads, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params1 []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params1)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 {
		t.FailNow()
	}

H
Helin Wang 已提交
68 69
	// don't compare content, since it's already changed by
	// gradient update.
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
	params1[0].Content = nil
	params1[0].Content = nil
	p.Content = nil
	p1.Content = nil

	if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
		t.FailNow()
	}
}

// TestMultipleInit checks the initialization protocol:
//   - BeginInitParams may be called repeatedly before FinishInitParams.
//   - After a successful FinishInitParams, both a second FinishInitParams
//     and a later BeginInitParams return pserver.ErrAlreadyIntialized.
func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	if err := s.BeginInitParams(nil, &dummy); err != nil {
		t.Fatalf("first BeginInitParams: %v", err)
	}

	// This is fine: a client may call init multiple times.
	if err := s.BeginInitParams(nil, &dummy); err != nil {
		t.Fatalf("second BeginInitParams: %v", err)
	}

	if err := s.FinishInitParams(0, &dummy); err != nil {
		t.Fatalf("first FinishInitParams: %v", err)
	}

	if err := s.FinishInitParams(0, &dummy); err != pserver.ErrAlreadyIntialized {
		t.Fatalf("second FinishInitParams: got %v, want %v", err, pserver.ErrAlreadyIntialized)
	}

	if err := s.BeginInitParams(nil, &dummy); err != pserver.ErrAlreadyIntialized {
		t.Fatalf("BeginInitParams after FinishInitParams: got %v, want %v", err, pserver.ErrAlreadyIntialized)
	}
}

// TestUninitialized checks that SendGrads on a freshly created service,
// on which initialization was never started, returns
// pserver.ErrUnintialized.
func TestUninitialized(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	if err := s.SendGrads(nil, &dummy); err != pserver.ErrUnintialized {
		t.Fatalf("SendGrads before initialization: got %v, want %v", err, pserver.ErrUnintialized)
	}
}

// TestBlockUntilInitialized starts a GetParams call and a SaveModel call
// in separate goroutines before initializing the service. It then runs
// BeginInitParams/FinishInitParams and waits for both calls to return
// without error.
// NOTE(review): this only checks that the calls eventually succeed; it
// does not assert that they were actually blocked before FinishInitParams.
func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
	var wg sync.WaitGroup

	// t.FailNow/t.Fatal must only be called from the goroutine running the
	// test, so the helper goroutines report failures with t.Errorf. wg.Done
	// is deferred so that wg.Wait below cannot hang if a call fails.
	wg.Add(2)
	go func() {
		defer wg.Done()
		var params []pserver.Parameter
		if err := s.GetParams(nil, &params); err != nil {
			t.Errorf("GetParams: %v", err)
		}
	}()

	go func() {
		defer wg.Done()
		var dummy int
		if err := s.SaveModel("", &dummy); err != nil {
			t.Errorf("SaveModel: %v", err)
		}
	}()

	var dummy int
	if err := s.BeginInitParams(nil, &dummy); err != nil {
		t.Fatalf("BeginInitParams: %v", err)
	}

	if err := s.FinishInitParams(0, &dummy); err != nil {
		t.Fatalf("FinishInitParams: %v", err)
	}

	wg.Wait()
}