service_test.go 3.9 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package pserver_test

import (
18
	"io/ioutil"
D
dongzhihong 已提交
19
	"reflect"
20 21
	"sync"
	"testing"
22
	"time"
23

24
	"github.com/PaddlePaddle/Paddle/go/pserver"
25 26
)

Q
Qiao Longfei 已提交
27 28 29 30
// OptimizerConfig is the relative path of the file the tests in this
// package read and pass to the service as the optimizer configuration
// when initializing parameters.
const OptimizerConfig = "./client/c/test/testdata/optimizer.pb"

D
dongzhihong 已提交
31
func TestServiceFull(t *testing.T) {
D
dongzhihong 已提交
32
	var cp pserver.Checkpoint
H
Helin Wang 已提交
33
	s, err := pserver.NewService(0, time.Second, "", nil, cp)
W
wuyi05 已提交
34 35 36
	if err != nil {
		t.Error(err)
	}
37 38
	var p pserver.Parameter
	p.Name = "param_a"
D
dongzhihong 已提交
39
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
40
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
41
	config, err := ioutil.ReadFile(OptimizerConfig)
42
	if err != nil {
43
		t.Fatalf("read optimizer proto failed")
44 45
	}

46
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
47
	if err != nil {
48
		t.Fatal(err)
49 50
	}

D
dongzhihong 已提交
51 52 53 54 55 56
	var p1 pserver.Parameter
	p1.Name = "param_b"
	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	p1.ElementType = pserver.Float32
	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil)
	if err != nil {
57
		t.Fatal(err)
D
dongzhihong 已提交
58 59 60 61
	}

	err = s.FinishInitParams(0, nil)
	if err != nil {
62
		t.Fatal(err)
D
dongzhihong 已提交
63 64 65 66 67
	}

	var param pserver.Parameter
	err = s.GetParam("param_b", &param)
	if err != nil {
68
		t.Fatal(err)
D
dongzhihong 已提交
69 70 71
	}

	if !reflect.DeepEqual(param, p1) {
72
		t.Fatal("not equal:", param, p1)
D
dongzhihong 已提交
73 74 75 76 77 78
	}

	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)

	err = s.SendGrad(g1, nil)
	if err != nil {
79
		t.Fatal(err)
D
dongzhihong 已提交
80 81 82 83
	}
	err = s.SendGrad(g2, nil)

	if err != nil {
84
		t.Fatal(err)
D
dongzhihong 已提交
85 86 87 88 89
	}

	var param1 pserver.Parameter
	err = s.GetParam("param_a", &param1)
	if err != nil {
90
		t.Fatal(err)
D
dongzhihong 已提交
91 92 93 94 95 96 97 98
	}

	// don't compare content, since it's already changed by
	// gradient update.
	param1.Content = nil
	p.Content = nil

	if !reflect.DeepEqual(param1, p) {
99
		t.Fatal("not equal:", param1, p)
D
dongzhihong 已提交
100 101 102 103
	}
}

func TestMultipleInit(t *testing.T) {
D
dongzhihong 已提交
104
	var cp pserver.Checkpoint
H
Helin Wang 已提交
105
	s, err := pserver.NewService(0, time.Second, "", nil, cp)
W
wuyi05 已提交
106
	if err != nil {
107
		t.Fatal(err)
W
wuyi05 已提交
108 109
	}
	err = s.FinishInitParams(0, nil)
D
dongzhihong 已提交
110
	if err != nil {
111
		t.Fatal(err)
D
dongzhihong 已提交
112 113 114 115
	}

	err = s.FinishInitParams(0, nil)
	if err.Error() != pserver.AlreadyInitialized {
116
		t.Fatal(err)
D
dongzhihong 已提交
117
	}
118 119
}

120
func TestUninitialized(t *testing.T) {
D
dongzhihong 已提交
121
	var cp pserver.Checkpoint
H
Helin Wang 已提交
122
	s, err := pserver.NewService(0, time.Second, "", nil, cp)
W
wuyi05 已提交
123
	err = s.SendGrad(pserver.Gradient{}, nil)
H
Helin Wang 已提交
124
	if err.Error() != pserver.Uninitialized {
125
		t.Fatal(err)
126 127 128
	}
}

129
func TestBlockUntilInitialized(t *testing.T) {
D
dongzhihong 已提交
130
	var cp pserver.Checkpoint
H
Helin Wang 已提交
131
	s, err := pserver.NewService(0, time.Second, "", nil, cp)
W
wuyi05 已提交
132 133 134
	if err != nil {
		t.Error(err)
	}
135
	ch := make(chan struct{}, 2)
H
Helin Wang 已提交
136
	errCh := make(chan error, 2)
137 138 139
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
140 141
		var param pserver.Parameter
		err := s.GetParam("param_a", &param)
142
		if err != nil {
H
Helin Wang 已提交
143
			errCh <- err
144 145
		}
		wg.Done()
H
Helin Wang 已提交
146
		ch <- struct{}{}
147 148
	}()

149
	time.Sleep(50 * time.Millisecond)
150

H
Helin Wang 已提交
151 152 153 154
	select {
	case <-ch:
		// some function returned before initialization is completed.
		t.FailNow()
H
Helin Wang 已提交
155 156
	case <-errCh:
		t.FailNow()
H
Helin Wang 已提交
157 158 159
	default:
	}

160 161 162 163
	var p pserver.Parameter
	p.Name = "param_a"
	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
	p.ElementType = pserver.Int32
Q
Qiao Longfei 已提交
164
	config, err := ioutil.ReadFile(OptimizerConfig)
165 166 167 168
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
	err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: config}, nil)
169

170
	if err != nil {
171
		t.Fatal(err)
172 173
	}

174
	err = s.FinishInitParams(0, nil)
175
	if err != nil {
176
		t.Fatal(err)
177 178 179 180
	}

	wg.Wait()
}
D
dongzhihong 已提交
181 182

// TestCheckpointSpeed is a placeholder for a checkpoint speed test.
// It is reported as skipped rather than passed so the missing
// coverage stays visible in test output.
func TestCheckpointSpeed(t *testing.T) {
	//TODO(zhihong): test speed
	t.Skip("checkpoint speed test is not implemented yet")
}