// Tests for pserver.Service.
package pserver_test

import (
	"reflect"
	"sync"
	"testing"

	"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
	err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
		t.FailNow()
	}

	grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
	err = s.SendGrads(grads, &dummy)
	if err != nil {
		t.FailNow()
	}

	var params1 []pserver.Parameter
	err = s.GetParams([]string{"param_b", "param_a"}, &params1)
	if err != nil {
		t.FailNow()
	}

	if len(params) != 2 {
		t.FailNow()
	}

H
Helin Wang 已提交
68 69
	// don't compare content, since it's already changed by
	// gradient update.
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
	params1[0].Content = nil
	params1[0].Content = nil
	p.Content = nil
	p1.Content = nil

	if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
		t.FailNow()
	}
}

func TestMultipleInit(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	// this is fine, it's possible for client to call init
	// multiple times.
	err = s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	err = s.FinishInitParams(0, &dummy)
H
Helin Wang 已提交
101
	if err != pserver.ErrAlreadyInitialized {
102 103 104 105
		t.FailNow()
	}

	err = s.BeginInitParams(nil, &dummy)
H
Helin Wang 已提交
106
	if err != pserver.ErrAlreadyInitialized {
107 108 109 110
		t.FailNow()
	}
}

111 112 113 114 115 116 117 118 119
func TestUninitialized(t *testing.T) {
	s := pserver.NewService()
	var dummy int
	err := s.SendGrads(nil, &dummy)
	if err != pserver.ErrUninitialized {
		t.FailNow()
	}
}

120 121
func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
122
	ch := make(chan struct{}, 2)
123 124 125 126 127 128 129 130 131
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		var params []pserver.Parameter
		err := s.GetParams(nil, &params)
		if err != nil {
			t.FailNow()
		}
		wg.Done()
H
Helin Wang 已提交
132
		ch <- struct{}{}
133 134 135 136 137
	}()

	wg.Add(1)
	go func() {
		var dummy int
H
Helin Wang 已提交
138
		err := s.Save("", &dummy)
139 140 141 142
		if err != nil {
			t.FailNow()
		}
		wg.Done()
H
Helin Wang 已提交
143
		ch <- struct{}{}
144 145 146 147 148 149 150 151
	}()

	var dummy int
	err := s.BeginInitParams(nil, &dummy)
	if err != nil {
		t.FailNow()
	}

H
Helin Wang 已提交
152 153 154 155 156 157 158
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
	default:
	}

159 160 161 162 163 164 165
	err = s.FinishInitParams(0, &dummy)
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}