// service_test.go: tests for pserver.Service.

package pserver_test

import (
	"reflect"
	"sync"
	"testing"

	"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
	err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
		t.FailNow()
	}

	grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
	err = s.SendGrads(grads, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params1 []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params1)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 {
		t.FailNow()
	}

H
Helin Wang 已提交
68 69
	// don't compare content, since it's already changed by
	// gradient update.
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
	params1[0].Content = nil
	params1[0].Content = nil
	p.Content = nil
	p1.Content = nil

	if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	// this is fine, it's possible for client to call init
	// multiple times.
	err = s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
H
Helin Wang 已提交
101
	if err != pserver.ErrAlreadyInitialized {
102 103 104 105
		t.FailNow()
	}

	err = s.BeginInitParams(nil, &dummy)
H
Helin Wang 已提交
106
	if err != pserver.ErrAlreadyInitialized {
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
		t.FailNow()
	}
}

func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		var params []pserver.Parameter
		err := s.GetParams(nil, &params)
		if err != nil {
			t.FailNow()
		}
		wg.Done()
	}()

	wg.Add(1)
	go func() {
		var dummy int
		err := s.SaveModel("", &dummy)
		if err != nil {
			t.FailNow()
		}
		wg.Done()
	}()

H
Helin Wang 已提交
134 135 136 137 138 139 140 141 142 143
	wg.Add(1)
	go func() {
		var dummy int
		err := s.SendGrads(nil, &dummy)
		if err != nil {
			t.FailNow()
		}
		wg.Done()
	}()

144 145 146 147 148 149 150 151 152 153 154 155 156
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}