package pserver_test

import (
	"io/ioutil"
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/PaddlePaddle/Paddle/go/pserver"
)

// OptimizerConfig is the relative path of the optimizer proto file that the
// tests in this file read and pass as the Config of a
// pserver.ParameterWithConfig.
const OptimizerConfig = "./client/c/test/testdata/optimizer.pb"

func TestServiceFull(t *testing.T) {
D
dongzhihong 已提交
18 19
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
20 21 22
	if err != nil {
		t.Error(err)
	}
23 24
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
25
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
26
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
27
	config, err := ioutil.ReadFile(OptimizerConfig)
28
	if err != nil {
29
		t.Fatalf("read optimizer proto failed")
30 31
	}

32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
33 34 35 36
	if err != nil {
		t.FailNow()
	}

D
dongzhihong 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
		t.FailNow()
	}

	if !reflect.DeepEqual(param, p1) {
		t.FailNow()
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
		t.FailNow()
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
		t.FailNow()
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
		t.FailNow()
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
D
dongzhihong 已提交
90 91
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
92 93 94 95
	if err != nil {
		t.Error(err)
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
96 97 98 99 100 101 102 103
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
		t.FailNow()
	}
104 105
}

func TestUninitialized(t *testing.T) {
D
dongzhihong 已提交
107 108
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
109
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
110
	if err.Error() != pserver.Uninitialized {
111 112 113 114
		t.FailNow()
	}
}

func TestBlockUntilInitialized(t *testing.T) {
D
dongzhihong 已提交
116 117
	var cp pserver.Checkpoint
	s, err := pserver.NewService(0, 1, "", nil, cp)
W
wuyi05 已提交
118 119 120
	if err != nil {
		t.Error(err)
	}
121
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
122
	errCh := make(chan error, 2)
123 124 125
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
126 127
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
128
		if err != nil {
H
Helin Wang 已提交
129
			errCh <- err
130 131
		}
		wg.Done()
H
Helin Wang 已提交
132
		ch <- struct{}{}
133 134
	}()

135
	time.Sleep(50 * time.Millisecond)
136

H
Helin Wang 已提交
137 138 139 140
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
141 142
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
143 144 145
	default:
	}

146 147 148 149
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
150
	config, err := ioutil.ReadFile(OptimizerConfig)
151 152 153 154
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
155

156 157 158 159
	if err != nil {
		t.FailNow()
	}

160
	err = s.FinishInitParams(0, nil)
161 162 163 164 165 166
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}

// TestCheckpointSpeed is a placeholder for a checkpoint speed test. It is
// reported as skipped, rather than as a vacuous pass, until it is written.
func TestCheckpointSpeed(t *testing.T) {
	//TODO(zhihong): test speed
	t.Skip("TODO(zhihong): checkpoint speed test is not implemented yet")
}