service_test.go 3.1 KB
Newer Older
1 2 3
package pserver_test

import (
4
	"io/ioutil"
D
dongzhihong 已提交
5
	"reflect"
6 7
	"sync"
	"testing"
8
	"time"
9

10
	"github.com/PaddlePaddle/Paddle/go/pserver"
11 12
)

Q
Qiao Longfei 已提交
13 14 15 16
const (
	// OptimizerConfig is the path of the optimizer proto file that the
	// tests in this file read and pass as the Config of each
	// pserver.ParameterWithConfig.
	// NOTE(review): the path is relative, so it presumably resolves
	// against the directory the tests are run from — confirm.
	OptimizerConfig = "./client/c/test/testdata/optimizer.pb"
)

D
dongzhihong 已提交
17
func TestServiceFull(t *testing.T) {
18
	s, err := pserver.NewService(0)
W
wuyi05 已提交
19 20 21
	if err != nil {
		t.Error(err)
	}
22 23
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
24
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
25
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
26
	config, err := ioutil.ReadFile(OptimizerConfig)
27
	if err != nil {
28
		t.Fatalf("read optimizer proto failed")
29 30
	}

31
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
32 33 34 35
	if err != nil {
		t.FailNow()
	}

D
dongzhihong 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
		t.FailNow()
	}

	if !reflect.DeepEqual(param, p1) {
		t.FailNow()
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
		t.FailNow()
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
		t.FailNow()
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
		t.FailNow()
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
89
	s, err := pserver.NewService(0)
W
wuyi05 已提交
90 91 92 93
	if err != nil {
		t.Error(err)
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
94 95 96 97 98 99 100 101
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
		t.FailNow()
	}
102 103
}

104
func TestUninitialized(t *testing.T) {
105
	s, err := pserver.NewService(0)
W
wuyi05 已提交
106
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
107
	if err.Error() != pserver.Uninitialized {
108 109 110 111
		t.FailNow()
	}
}

112
func TestBlockUntilInitialized(t *testing.T) {
113
	s, err := pserver.NewService(0)
W
wuyi05 已提交
114 115 116
	if err != nil {
		t.Error(err)
	}
117
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
118
	errCh := make(chan error, 2)
119 120 121
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
122 123
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
124
		if err != nil {
H
Helin Wang 已提交
125
			errCh <- err
126 127
		}
		wg.Done()
H
Helin Wang 已提交
128
		ch <- struct{}{}
129 130 131 132
	}()

	wg.Add(1)
	go func() {
133
		err := s.Save("", nil)
134
		if err != nil {
H
Helin Wang 已提交
135
			errCh <- err
136 137
		}
		wg.Done()
H
Helin Wang 已提交
138
		ch <- struct{}{}
139 140
	}()

141
	time.Sleep(50 * time.Millisecond)
142

H
Helin Wang 已提交
143 144 145 146
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
147 148
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
149 150 151
	default:
	}

152 153 154 155
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
156
	config, err := ioutil.ReadFile(OptimizerConfig)
157 158 159 160
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
161

162 163 164 165
	if err != nil {
		t.FailNow()
	}

166
	err = s.FinishInitParams(0, nil)
167 168 169 170 171 172
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}