package pserver_test

import (
	"io/ioutil"
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/PaddlePaddle/Paddle/go/pserver"
)

func TestServiceFull(t *testing.T) {
14
	s, err := pserver.NewService(0)
W
wuyi05 已提交
15 16 17
	if err != nil {
		t.Error(err)
	}
18 19
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
20
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
21
	p.ElementType = pserver.Int32
D
dongzhihong 已提交
22
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
23
	if err != nil {
24
		t.Fatalf("read optimizer proto failed")
25 26
	}

27
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
28 29 30 31
	if err != nil {
		t.FailNow()
	}

D
dongzhihong 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
		t.FailNow()
	}

	if !reflect.DeepEqual(param, p1) {
		t.FailNow()
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
		t.FailNow()
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
		t.FailNow()
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
		t.FailNow()
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
85
	s, err := pserver.NewService(0)
W
wuyi05 已提交
86 87 88 89
	if err != nil {
		t.Error(err)
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
90 91 92 93 94 95 96 97
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
		t.FailNow()
	}
98 99
}

func TestUninitialized(t *testing.T) {
101
	s, err := pserver.NewService(0)
W
wuyi05 已提交
102
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
103
	if err.Error() != pserver.Uninitialized {
104 105 106 107
		t.FailNow()
	}
}

func TestBlockUntilInitialized(t *testing.T) {
109
	s, err := pserver.NewService(0)
W
wuyi05 已提交
110 111 112
	if err != nil {
		t.Error(err)
	}
113
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
114
	errCh := make(chan error, 2)
115 116 117
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
118 119
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
120
		if err != nil {
H
Helin Wang 已提交
121
			errCh <- err
122 123
		}
		wg.Done()
H
Helin Wang 已提交
124
		ch <- struct{}{}
125 126 127 128
	}()

	wg.Add(1)
	go func() {
129
		err := s.Save("", nil)
130
		if err != nil {
H
Helin Wang 已提交
131
			errCh <- err
132 133
		}
		wg.Done()
H
Helin Wang 已提交
134
		ch <- struct{}{}
135 136
	}()

137
	time.Sleep(50 * time.Millisecond)
138

H
Helin Wang 已提交
139 140 141 142
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
143 144
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
145 146 147
	default:
	}

148 149 150 151
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
D
dongzhihong 已提交
152
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
153 154 155 156
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
157

158 159 160 161
	if err != nil {
		t.FailNow()
	}

162
	err = s.FinishInitParams(0, nil)
163 164 165 166 167 168
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}