service_test.go 3.2 KB
Newer Older
1 2 3
package pserver_test

import (
4
	"io/ioutil"
5 6
	"sync"
	"testing"
7
	"time"
8

9
	"github.com/PaddlePaddle/Paddle/go/pserver"
10 11 12 13 14 15
)

func TestFull(t *testing.T) {
	s := pserver.NewService()
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
16
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
17
	p.ElementType = pserver.Int32
18
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
19
	if err != nil {
20
		t.Fatalf("read optimizer proto failed")
21 22
	}

23
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
24 25 26 27
	if err != nil {
		t.FailNow()
	}

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
	// 	var p1 pserver.Parameter
	// 	p1.Name = "param_b"
	// 	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	// 	p1.ElementType = pserver.Float32
	// 	fmt.Println("paddle passed")
	// 	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	// 	if err != nil {
	// 		t.FailNow()
	// 	}

	// 	err = s.FinishInitParams(0, nil)
	// 	if err != nil {
	// 		t.FailNow()
	// 	}

	// 	var param pserver.Parameter
	// 	err = s.GetParam("param_b", &param)
	// 	if err != nil {
	// 		t.FailNow()
	// 	}

	// 	if !reflect.DeepEqual(param, p1) {
	// 		t.FailNow()
	// 	}

	// 	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
	// 	err = s.SendGrad(g1, nil)
	// 	if err != nil {
	// 		t.FailNow()
	// 	}
	// 	err = s.SendGrad(g2, nil)

	// 	if err != nil {
	// 		t.FailNow()
	// 	}

	// 	var param1 pserver.Parameter
	// 	err = s.GetParam("param_a", &param1)
	// 	if err != nil {
	// 		t.FailNow()
	// 	}

	// 	// don't compare content, since it's already changed by
	// 	// gradient update.
	// 	param1.Content = nil
	// 	p.Content = nil

	// 	if !reflect.DeepEqual(param1, p) {
	// 		t.FailNow()
	// 	}
	// }

	// func TestMultipleInit(t *testing.T) {
	// 	s := pserver.NewService()
	// 	err := s.FinishInitParams(0, nil)
	// 	if err != nil {
	// 		t.FailNow()
	// 	}

	// 	err = s.FinishInitParams(0, nil)
	// 	if err.Error() != pserver.AlreadyInitialized {
	// 		t.FailNow()
	// 	}
91 92
}

93 94
func TestUninitialized(t *testing.T) {
	s := pserver.NewService()
95
	err := s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
96
	if err.Error() != pserver.Uninitialized {
97 98 99 100
		t.FailNow()
	}
}

101 102
func TestBlockUntilInitialized(t *testing.T) {
	s := pserver.NewService()
103
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
104
	errCh := make(chan error, 2)
105 106 107
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
108 109
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
110
		if err != nil {
H
Helin Wang 已提交
111
			errCh <- err
112 113
		}
		wg.Done()
H
Helin Wang 已提交
114
		ch <- struct{}{}
115 116 117 118
	}()

	wg.Add(1)
	go func() {
119
		err := s.Save("", nil)
120
		if err != nil {
H
Helin Wang 已提交
121
			errCh <- err
122 123
		}
		wg.Done()
H
Helin Wang 已提交
124
		ch <- struct{}{}
125 126
	}()

127
	time.Sleep(50 * time.Millisecond)
128

H
Helin Wang 已提交
129 130 131 132
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
133 134
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
135 136 137
	default:
	}

138 139 140 141
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
142 143 144 145 146
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt")
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
147 148 149 150
	if err != nil {
		t.FailNow()
	}

151
	err = s.FinishInitParams(0, nil)
152 153 154 155 156 157
	if err != nil {
		t.FailNow()
	}

	wg.Wait()
}