// service_test.go: tests for pserver.Service.

package pserver_test

import (
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/PaddlePaddle/Paddle/go/pserver"
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
16
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
17
	p.ElementType = pserver.Int32
H
Helin Wang 已提交
18
	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
19 20 21 22 23 24
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
D
dongzhihong 已提交
25
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
26
	p1.ElementType = pserver.Float32
H
Helin Wang 已提交
27
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
28 29 30 31
	if err != nil {
		t.FailNow()
	}

32
	err = s.FinishInitParams(0, nil)
33 34 35 36
	if err != nil {
		t.FailNow()
	}

37 38
	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
39 40 41 42
	if err != nil {
		t.FailNow()
	}

43
	if !reflect.DeepEqual(param, p1) {
44 45 46
		t.FailNow()
	}

47
	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
48
	err = s.SendGrad(g1, nil)
49 50 51
	if err != nil {
		t.FailNow()
	}
52
	err = s.SendGrad(g2, nil)
53 54 55 56 57

	if err != nil {
		t.FailNow()
	}

58 59 60
	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
61 62 63
		t.FailNow()
	}

H
Helin Wang 已提交
64 65
	// don't compare content, since it's already changed by
	// gradient update.
66
	param1.Content = nil
67 68
	p.Content = nil

69
	if !reflect.DeepEqual(param1, p) {
70 71 72 73 74 75
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
76
	err := s.FinishInitParams(0, nil)
77 78 79 80
	if err != nil {
		t.FailNow()
	}

81
	err = s.FinishInitParams(0, nil)
H
Helin Wang 已提交
82
	if err.Error() != pserver.AlreadyInitialized {
83 84 85 86
		t.FailNow()
	}
}

87 88
func TestUninitialized(t *testing.T) {
	s := pserver.NewService()
89
	err := s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
90
	if err.Error() != pserver.Uninitialized {
91 92 93 94
		t.FailNow()
	}
}

95 96
func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
97
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
98
	errCh := make(chan error, 2)
99 100 101
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
102 103
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
104
		if err != nil {
H
Helin Wang 已提交
105
			errCh <- err
106 107
		}
		wg.Done()
H
Helin Wang 已提交
108
		ch <- struct{}{}
109 110 111 112
	}()

	wg.Add(1)
	go func() {
113
		err := s.Save("", nil)
114
		if err != nil {
H
Helin Wang 已提交
115
			errCh <- err
116 117
		}
		wg.Done()
H
Helin Wang 已提交
118
		ch <- struct{}{}
119 120
	}()

121
	time.Sleep(50 * time.Millisecond)
122

H
Helin Wang 已提交
123 124 125 126
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
127 128
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
129 130 131
	default:
	}

132 133 134 135
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
H
Helin Wang 已提交
136
	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
137 138 139 140
	if err != nil {
		t.FailNow()
	}

141
	err = s.FinishInitParams(0, nil)
142 143 144 145 146 147
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}