service_test.go 2.8 KB
Newer Older
1 2 3 4 5 6
package pserver_test

import (
	"reflect"
	"sync"
	"testing"
7
	"time"
8

9
	"github.com/PaddlePaddle/Paddle/go/pserver"
10 11 12
)

func TestFull(t *testing.T) {
H
Helin Wang 已提交
13
	s, err := pserver.NewService("", 1, time.Second*5)
W
wuyi05 已提交
14 15 16
	if err != nil {
		t.Error(err)
	}
17 18 19 20
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
W
wuyi05 已提交
21
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
22 23 24 25 26 27 28 29
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
H
Helin Wang 已提交
30
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
31 32 33 34
	if err != nil {
		t.FailNow()
	}

35
	err = s.FinishInitParams(0, nil)
36 37 38 39
	if err != nil {
		t.FailNow()
	}

40 41
	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
42 43 44 45
	if err != nil {
		t.FailNow()
	}

46
	if !reflect.DeepEqual(param, p1) {
47 48 49
		t.FailNow()
	}

50
	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
51
	err = s.SendGrad(g1, nil)
52 53 54
	if err != nil {
		t.FailNow()
	}
55
	err = s.SendGrad(g2, nil)
56 57 58 59 60

	if err != nil {
		t.FailNow()
	}

61 62 63
	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
64 65 66
		t.FailNow()
	}

H
Helin Wang 已提交
67 68
	// don't compare content, since it's already changed by
	// gradient update.
69
	param1.Content = nil
70 71
	p.Content = nil

72
	if !reflect.DeepEqual(param1, p) {
73 74 75 76 77
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
H
Helin Wang 已提交
78
	s, err := pserver.NewService("", 1, time.Second*5)
W
wuyi05 已提交
79 80 81 82
	if err != nil {
		t.Error(err)
	}
	err = s.FinishInitParams(0, nil)
83 84 85 86
	if err != nil {
		t.FailNow()
	}

87
	err = s.FinishInitParams(0, nil)
H
Helin Wang 已提交
88
	if err.Error() != pserver.AlreadyInitialized {
89 90 91 92
		t.FailNow()
	}
}

93
func TestUninitialized(t *testing.T) {
H
Helin Wang 已提交
94
	s, err := pserver.NewService("", 1, time.Second*5)
W
wuyi05 已提交
95
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
96
	if err.Error() != pserver.Uninitialized {
97 98 99 100
		t.FailNow()
	}
}

101
func TestBlockUntilInitialized(t *testing.T) {
H
Helin Wang 已提交
102
	s, err := pserver.NewService("", 1, time.Second*5)
W
wuyi05 已提交
103 104 105
	if err != nil {
		t.Error(err)
	}
106
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
107
	errCh := make(chan error, 2)
108 109 110
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
111 112
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
113
		if err != nil {
H
Helin Wang 已提交
114
			errCh <- err
115 116
		}
		wg.Done()
H
Helin Wang 已提交
117
		ch <- struct{}{}
118 119 120 121
	}()

	wg.Add(1)
	go func() {
122
		err := s.Save("", nil)
123
		if err != nil {
H
Helin Wang 已提交
124
			errCh <- err
125 126
		}
		wg.Done()
H
Helin Wang 已提交
127
		ch <- struct{}{}
128 129
	}()

130
	time.Sleep(50 * time.Millisecond)
131

H
Helin Wang 已提交
132 133 134 135
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
136 137
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
138 139 140
	default:
	}

141 142 143 144
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
W
wuyi05 已提交
145
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
146 147 148 149
	if err != nil {
		t.FailNow()
	}

150
	err = s.FinishInitParams(0, nil)
151 152 153 154 155 156
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}