service_test.go 3.2 KB
Newer Older
1 2 3
package pserver_test

import (
4
	"io/ioutil"
D
dongzhihong 已提交
5
	"reflect"
6 7
	"sync"
	"testing"
8
	"time"
9

10
	"github.com/PaddlePaddle/Paddle/go/pserver"
11 12
)

D
dongzhihong 已提交
13
func TestServiceFull(t *testing.T) {
14
	s, err := pserver.NewService(0)
W
wuyi05 已提交
15 16 17
	if err != nil {
		t.Error(err)
	}
18 19
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
20
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
21
	p.ElementType = pserver.Int32
22
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
23
	if err != nil {
24
		t.Fatalf("read optimizer proto failed")
25 26
	}

27
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
28 29 30 31
	if err != nil {
		t.FailNow()
	}

D
dongzhihong 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
		t.FailNow()
	}

	if !reflect.DeepEqual(param, p1) {
		t.FailNow()
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
		t.FailNow()
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
		t.FailNow()
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
		t.FailNow()
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
D
dongzhihong 已提交
82 83
	var dummy int
	s.Save("", &dummy)
D
dongzhihong 已提交
84 85 86
}

func TestMultipleInit(t *testing.T) {
87
	s, err := pserver.NewService(0)
W
wuyi05 已提交
88 89 90 91
	if err != nil {
		t.Error(err)
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
92 93 94 95 96 97 98 99
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
		t.FailNow()
	}
100 101
}

102
func TestUninitialized(t *testing.T) {
103
	s, err := pserver.NewService(0)
W
wuyi05 已提交
104
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
105
	if err.Error() != pserver.Uninitialized {
106 107 108 109
		t.FailNow()
	}
}

110
func TestBlockUntilInitialized(t *testing.T) {
111
	s, err := pserver.NewService(0)
W
wuyi05 已提交
112 113 114
	if err != nil {
		t.Error(err)
	}
115
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
116
	errCh := make(chan error, 2)
117 118 119
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
120 121
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
122
		if err != nil {
H
Helin Wang 已提交
123
			errCh <- err
124 125
		}
		wg.Done()
H
Helin Wang 已提交
126
		ch <- struct{}{}
127 128 129 130
	}()

	wg.Add(1)
	go func() {
131
		err := s.Save("", nil)
132
		if err != nil {
H
Helin Wang 已提交
133
			errCh <- err
134 135
		}
		wg.Done()
H
Helin Wang 已提交
136
		ch <- struct{}{}
137 138
	}()

139
	time.Sleep(50 * time.Millisecond)
140

H
Helin Wang 已提交
141 142 143 144
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
145 146
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
147 148 149
	default:
	}

150 151 152 153
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
154 155 156 157 158
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
159

160 161 162 163
	if err != nil {
		t.FailNow()
	}

164
	err = s.FinishInitParams(0, nil)
165 166 167 168 169 170
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}
D
dongzhihong 已提交
171 172 173 174

// TestCheckpointSpeed is a placeholder: its body is empty, so it makes no
// assertions and always passes.
func TestCheckpointSpeed(t *testing.T) {
	// TODO: test speed
}