// service_test.go

package pserver_test

import (
	"reflect"
	"sync"
	"testing"

	"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
	err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
		t.FailNow()
	}

	grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
	err = s.SendGrads(grads, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params1 []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params1)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 {
		t.FailNow()
	}

H
Helin Wang 已提交
68 69
	// don't compare content, since it's already changed by
	// gradient update.
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
	params1[0].Content = nil
	params1[0].Content = nil
	p.Content = nil
	p1.Content = nil

	if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	// this is fine, it's possible for client to call init
	// multiple times.
	err = s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
H
Helin Wang 已提交
101
	if err != pserver.ErrAlreadyInitialized {
102 103 104 105
		t.FailNow()
	}

	err = s.BeginInitParams(nil, &dummy)
H
Helin Wang 已提交
106
	if err != pserver.ErrAlreadyInitialized {
107 108 109 110 111 112
		t.FailNow()
	}
}

func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
H
Helin Wang 已提交
113
	ch := make(chan struct{}, 3)
114 115 116 117 118 119 120 121 122
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		var params []pserver.Parameter
		err := s.GetParams(nil, &params)
		if err != nil {
			t.FailNow()
		}
		wg.Done()
H
Helin Wang 已提交
123
		ch <- struct{}{}
124 125 126 127 128
	}()

	wg.Add(1)
	go func() {
		var dummy int
H
Helin Wang 已提交
129
		err := s.Save("", &dummy)
130 131 132 133
		if err != nil {
			t.FailNow()
		}
		wg.Done()
H
Helin Wang 已提交
134
		ch <- struct{}{}
135 136
	}()

H
Helin Wang 已提交
137 138 139 140 141 142 143 144
	wg.Add(1)
	go func() {
		var dummy int
		err := s.SendGrads(nil, &dummy)
		if err != nil {
			t.FailNow()
		}
		wg.Done()
H
Helin Wang 已提交
145
		ch <- struct{}{}
H
Helin Wang 已提交
146 147
	}()

148 149 150 151 152 153
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

H
Helin Wang 已提交
154 155 156 157 158 159 160
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
	default:
	}

161 162 163 164 165 166 167
	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}