// Package pserver_test holds black-box tests for the service created by
// pserver.NewService.
package pserver_test

import (
	"io/ioutil"
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/PaddlePaddle/Paddle/go/pserver"
)

// OptimizerConfig is the path of the optimizer protobuf config that the tests
// in this file read and pass to InitParam. It is relative to the working
// directory, which `go test` sets to the package directory.
const OptimizerConfig = "./client/c/test/testdata/optimizer.pb"

D
dongzhihong 已提交
17
func TestServiceFull(t *testing.T) {
D
dongzhihong 已提交
18 19
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
20 21 22
	if err != nil {
		t.Error(err)
	}
23 24
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
25
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
26
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
27
	config, err := ioutil.ReadFile(OptimizerConfig)
28
	if err != nil {
29
		t.Fatalf("read optimizer proto failed")
30 31
	}

32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
33
	if err != nil {
34
		t.Fatal(err)
35 36
	}

D
dongzhihong 已提交
37 38 39 40 41 42
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
43
		t.Fatal(err)
D
dongzhihong 已提交
44 45 46 47
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
48
		t.Fatal(err)
D
dongzhihong 已提交
49 50 51 52 53
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
54
		t.Fatal(err)
D
dongzhihong 已提交
55 56 57
	}

	if !reflect.DeepEqual(param, p1) {
58
		t.Fatal("not equal:", param, p1)
D
dongzhihong 已提交
59 60 61 62 63 64
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
65
		t.Fatal(err)
D
dongzhihong 已提交
66 67 68 69
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
70
		t.Fatal(err)
D
dongzhihong 已提交
71 72 73 74 75
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
76
		t.Fatal(err)
D
dongzhihong 已提交
77 78 79 80 81 82 83 84
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
85
		t.Fatal("not equal:", param1, p)
D
dongzhihong 已提交
86 87 88 89
	}
}

func TestMultipleInit(t *testing.T) {
D
dongzhihong 已提交
90 91
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
92
	if err != nil {
93
		t.Fatal(err)
W
wuyi05 已提交
94 95
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
96
	if err != nil {
97
		t.Fatal(err)
D
dongzhihong 已提交
98 99 100 101
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
102
		t.Fatal(err)
D
dongzhihong 已提交
103
	}
104 105
}

106
func TestUninitialized(t *testing.T) {
D
dongzhihong 已提交
107 108
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
109
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
110
	if err.Error() != pserver.Uninitialized {
111
		t.Fatal(err)
112 113 114
	}
}

115
func TestBlockUntilInitialized(t *testing.T) {
D
dongzhihong 已提交
116 117
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
118 119 120
	if err != nil {
		t.Error(err)
	}
121
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
122
	errCh := make(chan error, 2)
123 124 125
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
126 127
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
128
		if err != nil {
H
Helin Wang 已提交
129
			errCh <- err
130 131
		}
		wg.Done()
H
Helin Wang 已提交
132
		ch <- struct{}{}
133 134
	}()

135
	time.Sleep(50 * time.Millisecond)
136

H
Helin Wang 已提交
137 138 139 140
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
141 142
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
143 144 145
	default:
	}

146 147 148 149
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
150
	config, err := ioutil.ReadFile(OptimizerConfig)
151 152 153 154
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
155

156
	if err != nil {
157
		t.Fatal(err)
158 159
	}

160
	err = s.FinishInitParams(0, nil)
161
	if err != nil {
162
		t.Fatal(err)
163 164 165 166
	}

	wg.Wait()
}
D
dongzhihong 已提交
167 168

// TestCheckpointSpeed is a placeholder for a checkpoint-speed test. It is
// reported as skipped, rather than silently passing, until it is written.
func TestCheckpointSpeed(t *testing.T) {
	// TODO(zhihong): test speed
	t.Skip("TODO(zhihong): test speed")
}