service_test.go 3.1 KB
Newer Older
1 2 3
package pserver_test

import (
4
	"io/ioutil"
D
dongzhihong 已提交
5
	"reflect"
6 7
	"sync"
	"testing"
8
	"time"
9

10
	"github.com/PaddlePaddle/Paddle/go/pserver"
11 12
)

13

14
func TestFull(t *testing.T) {
15
	s, err := pserver.NewService(0)
W
wuyi05 已提交
16 17 18
	if err != nil {
		t.Error(err)
	}
19 20
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
21
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
22
	p.ElementType = pserver.Int32
23
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
24
	if err != nil {
25
		t.Fatalf("read optimizer proto failed")
26 27
	}

28
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
29 30 31 32
	if err != nil {
		t.FailNow()
	}

D
dongzhihong 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
		t.FailNow()
	}

	if !reflect.DeepEqual(param, p1) {
		t.FailNow()
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
		t.FailNow()
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
		t.FailNow()
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
		t.FailNow()
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
86
	s, err := pserver.NewService(0)
W
wuyi05 已提交
87 88 89 90
	if err != nil {
		t.Error(err)
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
91 92 93 94 95 96 97 98
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
		t.FailNow()
	}
99 100
}

101
func TestUninitialized(t *testing.T) {
102
	s, err := pserver.NewService(0)
W
wuyi05 已提交
103
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
104
	if err.Error() != pserver.Uninitialized {
105 106 107 108
		t.FailNow()
	}
}

109
func TestBlockUntilInitialized(t *testing.T) {
110
	s, err := pserver.NewService(0)
W
wuyi05 已提交
111 112 113
	if err != nil {
		t.Error(err)
	}
114
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
115
	errCh := make(chan error, 2)
116 117 118
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
119 120
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
121
		if err != nil {
H
Helin Wang 已提交
122
			errCh <- err
123 124
		}
		wg.Done()
H
Helin Wang 已提交
125
		ch <- struct{}{}
126 127 128 129
	}()

	wg.Add(1)
	go func() {
130
		err := s.Save("", nil)
131
		if err != nil {
H
Helin Wang 已提交
132
			errCh <- err
133 134
		}
		wg.Done()
H
Helin Wang 已提交
135
		ch <- struct{}{}
136 137
	}()

138
	time.Sleep(50 * time.Millisecond)
139

H
Helin Wang 已提交
140 141 142 143
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
144 145
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
146 147 148
	default:
	}

149 150 151 152
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
153 154 155 156 157
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
158

159 160 161 162
	if err != nil {
		t.FailNow()
	}

163
	err = s.FinishInitParams(0, nil)
164 165 166 167 168 169
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}