client_test.go 5.7 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Q
Qiao Longfei 已提交
15
package client_test
16 17

import (
Q
Qiao Longfei 已提交
18
	"context"
D
dongzhihong 已提交
19
	"io/ioutil"
20
	"math/rand"
21 22 23 24 25
	"net"
	"net/http"
	"net/rpc"
	"strconv"
	"strings"
26
	"sync"
27
	"testing"
Q
Qiao Longfei 已提交
28
	"time"
29

H
Helin Wang 已提交
30
	"github.com/PaddlePaddle/Paddle/go/pserver"
Q
Qiao Longfei 已提交
31 32
	"github.com/PaddlePaddle/Paddle/go/pserver/client"
	"github.com/coreos/etcd/clientv3"
33
	log "github.com/inconshreveable/log15"
34 35
)

// Fixture configuration shared by the tests in this file.
const (
	// numPserver is how many in-process parameter servers are started.
	numPserver = 10
	// etcdEndpoints is the etcd address used by initEtcdClient and EtcdClient.
	etcdEndpoints = "127.0.0.1:2379"
	// timeout bounds the context used for the etcd requests in initEtcdClient.
	timeout = 2 * time.Second
)

// pserverClientPorts holds the server ports chosen by initNativeClient,
// indexed by server number.
var pserverClientPorts [numPserver]int

Q
Qiao Longfei 已提交
44 45 46
// this function init pserver client and return their ports in an array.
func initClient() [numPserver]int {
	var ports [numPserver]int
47 48 49 50 51 52 53 54 55 56 57
	for i := 0; i < numPserver; i++ {
		l, err := net.Listen("tcp", ":0")
		if err != nil {
			panic(err)
		}

		ss := strings.Split(l.Addr().String(), ":")
		p, err := strconv.Atoi(ss[len(ss)-1])
		if err != nil {
			panic(err)
		}
Q
Qiao Longfei 已提交
58
		ports[i] = p
59 60

		go func(l net.Listener) {
G
fix bug  
gongweibao 已提交
61
			var cp pserver.Checkpoint
62
			s, err := pserver.NewService(0, time.Hour, "", nil, cp)
W
wuyi05 已提交
63 64 65
			if err != nil {
				panic(err)
			}
66
			server := rpc.NewServer()
W
wuyi05 已提交
67
			err = server.Register(s)
68 69 70 71 72 73 74 75 76 77 78 79
			if err != nil {
				panic(err)
			}

			mux := http.NewServeMux()
			mux.Handle(rpc.DefaultRPCPath, server)
			err = http.Serve(l, mux)
			if err != nil {
				panic(err)
			}
		}(l)
	}
Q
Qiao Longfei 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92
	return ports
}

// initNativeClient starts the in-process servers and records their ports in
// pserverClientPorts for TestNativeClient to dial.
func initNativeClient() {
	ports := initClient()
	pserverClientPorts = ports
}

func initEtcdClient() {
	client, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{etcdEndpoints},
		DialTimeout: time.Second * time.Duration(1),
	})
	if err != nil {
93
		log.Error("error init etcd client", log.Ctx{"error": err})
Q
Qiao Longfei 已提交
94 95
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
H
Helin Wang 已提交
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
	_, err = client.Delete(ctx, pserver.PsDesired)
	if err != nil {
		panic(err)
	}

	_, err = client.Delete(ctx, pserver.PsPath)
	if err != nil {
		panic(err)
	}

	_, err = client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
	if err != nil {
		panic(err)
	}

Q
Qiao Longfei 已提交
111 112
	ports := initClient()
	for i := 0; i < numPserver; i++ {
H
Helin Wang 已提交
113 114 115 116
		_, err = client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
		if err != nil {
			panic(err)
		}
Q
Qiao Longfei 已提交
117 118
	}
	cancel()
H
Helin Wang 已提交
119 120 121 122
	err = client.Close()
	if err != nil {
		panic(err)
	}
123 124 125 126
}

// selector is a stub passed to client.NewClient. Its Select result is fixed
// by the selector's own boolean value.
type selector bool

// Select returns the selector's value and never fails.
func (s selector) Select() (bool, error) {
	chosen := bool(s)
	return chosen, nil
}

// Done does nothing and always returns nil.
func (s selector) Done() error { return nil }

Q
Qiao Longfei 已提交
135
type lister []client.Server
136

Q
Qiao Longfei 已提交
137
func (l lister) List() []client.Server {
138 139 140
	return l
}

141
func testClient(t *testing.T, c *client.Client) {
142 143 144 145 146
	selected, err := c.BeginInitParams()
	if err != nil {
		t.Fatal(err)
	}

147 148 149 150
	if !selected {
		t.Fatal("should be selected.")
	}

151
	const numParameter = 1000
Q
Qiao Longfei 已提交
152
	config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
D
dongzhihong 已提交
153 154 155
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
156 157

	var wg sync.WaitGroup
158
	for i := 0; i < numParameter; i++ {
159 160 161 162 163 164 165 166 167 168 169 170
		wg.Add(1)
		go func(i int) {
			var p pserver.Parameter
			p.Name = "p_" + strconv.Itoa(i)
			p.ElementType = pserver.Float32
			p.Content = make([]byte, (i+1)*100)
			err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
			if err != nil {
				t.Fatal(err)
			}
			wg.Done()
		}(i)
171
	}
172
	wg.Wait()
173

D
dongzhihong 已提交
174
	err = c.FinishInitParams()
175 176 177 178 179
	if err != nil {
		t.Fatal(err)
	}

	var grads []pserver.Gradient
180
	for i := 0; i < numParameter; i++ {
181 182 183
		var g pserver.Gradient
		g.Name = "p_" + strconv.Itoa(i)
		g.ElementType = pserver.Float32
D
dongzhihong 已提交
184
		g.Content = make([]byte, (i+1)*100)
185 186 187
		grads = append(grads, g)
	}

188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
	const paramPerGroup = 10
	const numGroups = numParameter / paramPerGroup

	// shuffle send grads order
	for i := range grads {
		j := rand.Intn(i + 1)
		grads[i], grads[j] = grads[j], grads[i]
	}

	for i := 0; i < numGroups; i++ {
		var gs []pserver.Gradient
		if i == numGroups-1 {
			gs = grads[i*paramPerGroup:]
		} else {
			gs = grads[i*paramPerGroup : (i+1)*paramPerGroup]
		}

		wg.Add(1)
		go func(gs []pserver.Gradient) {
H
Helin Wang 已提交
207
			err := c.SendGrads(gs)
208 209 210 211 212
			if err != nil {
				t.Fatal(err)
			}
			wg.Done()
		}(gs)
213 214 215 216 217 218 219
	}

	names := make([]string, numParameter)
	for i := 0; i < numParameter; i++ {
		names[i] = "p_" + strconv.Itoa(i)
	}

220 221 222 223 224 225 226
	for i := 0; i < numGroups; i++ {
		var ns []string
		if i == numGroups-1 {
			ns = names[i*paramPerGroup:]
		} else {
			ns = names[i*paramPerGroup : (i+1)*paramPerGroup]
		}
227

228 229 230 231 232 233
		wg.Add(1)
		go func(ns []string) {
			params, err := c.GetParams(ns)
			if err != nil {
				t.Fatal(err)
			}
234

235 236 237 238 239 240 241 242 243 244 245
			if len(ns) != len(params) {
				t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
			}

			for i := range params {
				if ns[i] != params[i].Name {
					t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", ns[i], params[i].Name)
				}
			}
			wg.Done()
		}(ns)
246
	}
247 248

	wg.Wait()
249
}
Q
Qiao Longfei 已提交
250 251 252 253 254 255 256 257

func TestNativeClient(t *testing.T) {
	initNativeClient()
	servers := make([]client.Server, numPserver)
	for i := 0; i < numPserver; i++ {
		servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
	}
	c1 := client.NewClient(lister(servers), len(servers), selector(true))
258
	testClient(t, c1)
Q
Qiao Longfei 已提交
259 260
}

261 262
// EtcdClient is a disabled test, since we have not embedded etcd into
// our test.
Q
Qiao Longfei 已提交
263 264
func EtcdClient(t *testing.T) {
	initEtcdClient()
G
fix bug  
gongweibao 已提交
265 266
	etcdClient := client.NewEtcd(etcdEndpoints)
	c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
267
	testClient(t, c2)
Q
Qiao Longfei 已提交
268
}