package pserver_test

import (
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/PaddlePaddle/Paddle/go/pserver"
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var p pserver.Parameter
	p.Name = "param_a"
D
dzhwinter 已提交
16 17 18
	ElementValue := []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.Content = &ElementValue[0]
	p.Length = len(ElementValue)
19
	p.ElementType = pserver.Int32
H
Helin Wang 已提交
20
	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
21 22 23 24 25 26
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
D
dzhwinter 已提交
27 28 29
	ElementValue = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.Content = &ElementValue[0]
	p1.Length = len(ElementValue)
30
	p1.ElementType = pserver.Float32
H
Helin Wang 已提交
31
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
32 33 34 35
	if err != nil {
		t.FailNow()
	}

36
	err = s.FinishInitParams(0, nil)
37 38 39 40
	if err != nil {
		t.FailNow()
	}

41 42
	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
43 44 45 46
	if err != nil {
		t.FailNow()
	}

47
	if !reflect.DeepEqual(param, p1) {
48 49 50
		t.FailNow()
	}

51
	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
52
	err = s.SendGrad(g1, nil)
53 54 55
	if err != nil {
		t.FailNow()
	}
56
	err = s.SendGrad(g2, nil)
57 58 59 60 61

	if err != nil {
		t.FailNow()
	}

62 63 64
	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
65 66 67
		t.FailNow()
	}

H
Helin Wang 已提交
68 69
	// don't compare content, since it's already changed by
	// gradient update.
70
	param1.Content = nil
71 72
	p.Content = nil

73
	if !reflect.DeepEqual(param1, p) {
74 75 76 77 78 79
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
80
	err := s.FinishInitParams(0, nil)
81 82 83 84
	if err != nil {
		t.FailNow()
	}

85
	err = s.FinishInitParams(0, nil)
H
Helin Wang 已提交
86
	if err.Error() != pserver.AlreadyInitialized {
87 88 89 90
		t.FailNow()
	}
}

91 92
func TestUninitialized(t *testing.T) {
	s := pserver.NewService()
93
	err := s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
94
	if err.Error() != pserver.Uninitialized {
95 96 97 98
		t.FailNow()
	}
}

99 100
func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
101
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
102
	errCh := make(chan error, 2)
103 104 105
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
106 107
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
108
		if err != nil {
H
Helin Wang 已提交
109
			errCh <- err
110 111
		}
		wg.Done()
H
Helin Wang 已提交
112
		ch <- struct{}{}
113 114 115 116
	}()

	wg.Add(1)
	go func() {
117
		err := s.Save("", nil)
118
		if err != nil {
H
Helin Wang 已提交
119
			errCh <- err
120 121
		}
		wg.Done()
H
Helin Wang 已提交
122
		ch <- struct{}{}
123 124
	}()

125
	time.Sleep(50 * time.Millisecond)
126

H
Helin Wang 已提交
127 128 129 130
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
131 132
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
133 134 135
	default:
	}

136 137 138 139
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
H
Helin Wang 已提交
140
	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
141 142 143 144
	if err != nil {
		t.FailNow()
	}

145
	err = s.FinishInitParams(0, nil)
146 147 148 149 150 151
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}