service_test.go 3.2 KB
Newer Older
1 2 3
package pserver_test

import (
4
	"io/ioutil"
D
dongzhihong 已提交
5
	"reflect"
6 7
	"sync"
	"testing"
8
	"time"
9

10
	"github.com/PaddlePaddle/Paddle/go/pserver"
11 12
)

Q
Qiao Longfei 已提交
13 14 15 16
// OptimizerConfig is the path of the serialized optimizer proto that the
// tests read and pass as the Config of every parameter they initialize.
// It is resolved relative to the test's working directory.
const OptimizerConfig = "./client/c/test/testdata/optimizer.pb"

D
dongzhihong 已提交
17
func TestServiceFull(t *testing.T) {
D
dongzhihong 已提交
18 19
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
20 21 22
	if err != nil {
		t.Error(err)
	}
23 24
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
25
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
26
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
27
	config, err := ioutil.ReadFile(OptimizerConfig)
28
	if err != nil {
29
		t.Fatalf("read optimizer proto failed")
30 31
	}

32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
33 34 35 36
	if err != nil {
		t.FailNow()
	}

D
dongzhihong 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
		t.FailNow()
	}

	if !reflect.DeepEqual(param, p1) {
		t.FailNow()
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
		t.FailNow()
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
		t.FailNow()
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
		t.FailNow()
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
90
	s, err := pserver.NewService(0)
W
wuyi05 已提交
91 92 93 94
	if err != nil {
		t.Error(err)
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
95 96 97 98 99 100 101 102
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
		t.FailNow()
	}
103 104
}

105
func TestUninitialized(t *testing.T) {
106
	s, err := pserver.NewService(0)
W
wuyi05 已提交
107
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
108
	if err.Error() != pserver.Uninitialized {
109 110 111 112
		t.FailNow()
	}
}

113
func TestBlockUntilInitialized(t *testing.T) {
114
	s, err := pserver.NewService(0)
W
wuyi05 已提交
115 116 117
	if err != nil {
		t.Error(err)
	}
118
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
119
	errCh := make(chan error, 2)
120 121 122
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
123 124
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
125
		if err != nil {
H
Helin Wang 已提交
126
			errCh <- err
127 128
		}
		wg.Done()
H
Helin Wang 已提交
129
		ch <- struct{}{}
130 131 132 133
	}()

	wg.Add(1)
	go func() {
134
		err := s.Save("", nil)
135
		if err != nil {
H
Helin Wang 已提交
136
			errCh <- err
137 138
		}
		wg.Done()
H
Helin Wang 已提交
139
		ch <- struct{}{}
140 141
	}()

142
	time.Sleep(50 * time.Millisecond)
143

H
Helin Wang 已提交
144 145 146 147
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
148 149
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
150 151 152
	default:
	}

153 154 155 156
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
157
	config, err := ioutil.ReadFile(OptimizerConfig)
158 159 160 161
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
162

163 164 165 166
	if err != nil {
		t.FailNow()
	}

167
	err = s.FinishInitParams(0, nil)
168 169 170 171 172 173
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}
D
dongzhihong 已提交
174 175 176 177

// TestCheckpointSpeed is a placeholder for a checkpoint performance test.
// It is skipped explicitly so that test reports do not show an
// unimplemented test as passing.
func TestCheckpointSpeed(t *testing.T) {
	t.Skip("TODO: test speed")
}