service_test.go 2.7 KB
Newer Older
1 2 3 4 5 6
package pserver_test

import (
	"reflect"
	"sync"
	"testing"
7
	"time"
8

9
	"github.com/PaddlePaddle/Paddle/go/pserver"
10 11 12 13 14 15 16 17
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
18
	var dummy int
19
	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
20 21 22 23 24 25 26 27
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
28
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy)
29 30 31 32 33 34 35 36 37
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

38 39
	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
40 41 42 43
	if err != nil {
		t.FailNow()
	}

44
	if !reflect.DeepEqual(param, p1) {
45 46 47
		t.FailNow()
	}

48 49
	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
	err = s.SendGrad(g1, &dummy)
50 51 52
	if err != nil {
		t.FailNow()
	}
53
	err = s.SendGrad(g2, &dummy)
54 55 56 57 58

	if err != nil {
		t.FailNow()
	}

59 60 61
	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
62 63 64
		t.FailNow()
	}

H
Helin Wang 已提交
65 66
	// don't compare content, since it's already changed by
	// gradient update.
67
	param1.Content = nil
68 69
	p.Content = nil

70
	if !reflect.DeepEqual(param1, p) {
71 72 73 74 75 76 77
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
	var dummy int
78
	err := s.FinishInitParams(0, &dummy)
79 80 81 82 83
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
84
	if err.Error() != pserver.AlreadyInitialized {
85 86 87 88
		t.FailNow()
	}
}

89 90 91
func TestUninitialized(t *testing.T) {
	s := pserver.NewService()
	var dummy int
92
	err := s.SendGrad(pserver.Gradient{}, &dummy)
93
	if err.Error() != pserver.Uninitialized {
94 95 96 97
		t.FailNow()
	}
}

98 99
func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
100
	ch := make(chan struct{}, 2)
101 102 103
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
104 105
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
106 107 108 109
		if err != nil {
			t.FailNow()
		}
		wg.Done()
H
Helin Wang 已提交
110
		ch <- struct{}{}
111 112 113 114 115
	}()

	wg.Add(1)
	go func() {
		var dummy int
H
Helin Wang 已提交
116
		err := s.Save("", &dummy)
117 118 119 120
		if err != nil {
			t.FailNow()
		}
		wg.Done()
H
Helin Wang 已提交
121
		ch <- struct{}{}
122 123
	}()

124
	time.Sleep(50 * time.Millisecond)
125

H
Helin Wang 已提交
126 127 128 129 130 131 132
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
	default:
	}

133 134 135 136 137
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
	var dummy int
138
	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
139 140 141 142
	if err != nil {
		t.FailNow()
	}

143 144 145 146 147 148 149
	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}