service_test.go 2.9 KB
Newer Older
1 2 3
package pserver_test

import (
4
	"io/ioutil"
D
dongzhihong 已提交
5
	"reflect"
6 7
	"sync"
	"testing"
8
	"time"
9

10
	"github.com/PaddlePaddle/Paddle/go/pserver"
11 12
)

D
dongzhihong 已提交
13
func TestServiceFull(t *testing.T) {
14 15 16
	s := pserver.NewService()
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
17
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
18
	p.ElementType = pserver.Int32
19
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
20
	if err != nil {
21
		t.Fatalf("read optimizer proto failed")
22 23
	}

24
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
25 26 27 28
	if err != nil {
		t.FailNow()
	}

D
dongzhihong 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
		t.FailNow()
	}

	if !reflect.DeepEqual(param, p1) {
		t.FailNow()
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
		t.FailNow()
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
		t.FailNow()
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
		t.FailNow()
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
	err := s.FinishInitParams(0, nil)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
		t.FailNow()
	}
92 93
}

94 95
func TestUninitialized(t *testing.T) {
	s := pserver.NewService()
96
	err := s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
97
	if err.Error() != pserver.Uninitialized {
98 99 100 101
		t.FailNow()
	}
}

102 103
func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
104
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
105
	errCh := make(chan error, 2)
106 107 108
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
109 110
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
111
		if err != nil {
H
Helin Wang 已提交
112
			errCh <- err
113 114
		}
		wg.Done()
H
Helin Wang 已提交
115
		ch <- struct{}{}
116 117 118 119
	}()

	wg.Add(1)
	go func() {
120
		err := s.Save("", nil)
121
		if err != nil {
H
Helin Wang 已提交
122
			errCh <- err
123 124
		}
		wg.Done()
H
Helin Wang 已提交
125
		ch <- struct{}{}
126 127
	}()

128
	time.Sleep(50 * time.Millisecond)
129

H
Helin Wang 已提交
130 131 132 133
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
134 135
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
136 137 138
	default:
	}

139 140 141 142
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
143 144 145 146 147
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
148 149 150 151
	if err != nil {
		t.FailNow()
	}

152
	err = s.FinishInitParams(0, nil)
153 154 155 156 157 158
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}