client_test.go 3.9 KB
Newer Older
Q
Qiao Longfei 已提交
1
package client_test
2 3

import (
Q
Qiao Longfei 已提交
4
	"context"
D
dongzhihong 已提交
5
	"io/ioutil"
6 7 8 9 10 11
	"net"
	"net/http"
	"net/rpc"
	"strconv"
	"strings"
	"testing"
Q
Qiao Longfei 已提交
12
	"time"
13

H
Helin Wang 已提交
14
	"github.com/PaddlePaddle/Paddle/go/pserver"
Q
Qiao Longfei 已提交
15 16 17
	"github.com/PaddlePaddle/Paddle/go/pserver/client"
	"github.com/coreos/etcd/clientv3"
	log "github.com/sirupsen/logrus"
18 19
)

Q
Qiao Longfei 已提交
20 21 22 23 24
const (
	// numPserver is the number of in-process pserver services the tests start.
	numPserver = 10

	// etcdEndpoints is the etcd address used by the (currently disabled) etcd test.
	etcdEndpoints = "127.0.0.1:2379"

	// timeout bounds the etcd operations performed while seeding test state.
	timeout = 2 * time.Second
)
25

Q
Qiao Longfei 已提交
26
// pserverClientPorts records, by index, the TCP port of each pserver started by initNativeClient.
var pserverClientPorts = [numPserver]int{}
27

Q
Qiao Longfei 已提交
28 29 30
// this function init pserver client and return their ports in an array.
func initClient() [numPserver]int {
	var ports [numPserver]int
31 32 33 34 35 36 37 38 39 40 41
	for i := 0; i < numPserver; i++ {
		l, err := net.Listen("tcp", ":0")
		if err != nil {
			panic(err)
		}

		ss := strings.Split(l.Addr().String(), ":")
		p, err := strconv.Atoi(ss[len(ss)-1])
		if err != nil {
			panic(err)
		}
Q
Qiao Longfei 已提交
42
		ports[i] = p
43 44

		go func(l net.Listener) {
45
			s, err := pserver.NewService(0)
W
wuyi05 已提交
46 47 48
			if err != nil {
				panic(err)
			}
49
			server := rpc.NewServer()
W
wuyi05 已提交
50
			err = server.Register(s)
51 52 53 54 55 56 57 58 59 60 61 62
			if err != nil {
				panic(err)
			}

			mux := http.NewServeMux()
			mux.Handle(rpc.DefaultRPCPath, server)
			err = http.Serve(l, mux)
			if err != nil {
				panic(err)
			}
		}(l)
	}
Q
Qiao Longfei 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
	return ports
}

// initNativeClient starts the in-process pservers and stores their ports in
// pserverClientPorts so TestNativeClient can build its server list.
func initNativeClient() {
	ports := initClient()
	pserverClientPorts = ports
}

func initEtcdClient() {
	client, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{etcdEndpoints},
		DialTimeout: time.Second * time.Duration(1),
	})
	if err != nil {
		log.Errorf("err %v", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	client.Delete(ctx, pserver.PsDesired)
	client.Delete(ctx, pserver.PsPath)
	client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
	ports := initClient()
	for i := 0; i < numPserver; i++ {
		client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
	}
	cancel()
	client.Close()
88 89 90 91 92 93 94 95
}

// selector gives a fixed answer to client selection: Select reports the bool
// value the selector was constructed with.
type selector bool

// Select returns the selector's underlying bool value.
func (s selector) Select() bool {
	selected := bool(s)
	return selected
}

Q
Qiao Longfei 已提交
96
type lister []client.Server
97

Q
Qiao Longfei 已提交
98
func (l lister) List() []client.Server {
99 100 101
	return l
}

Q
Qiao Longfei 已提交
102
func ClientTest(t *testing.T, c *client.Client) {
103 104 105 106 107 108
	selected := c.BeginInitParams()
	if !selected {
		t.Fatal("should be selected.")
	}

	const numParameter = 100
Q
Qiao Longfei 已提交
109
	config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
D
dongzhihong 已提交
110 111 112
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
113 114 115 116
	for i := 0; i < numParameter; i++ {
		var p pserver.Parameter
		p.Name = "p_" + strconv.Itoa(i)
		p.ElementType = pserver.Float32
D
dongzhihong 已提交
117
		p.Content = make([]byte, (i+1)*100)
D
dongzhihong 已提交
118
		err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
119 120 121 122 123
		if err != nil {
			t.Fatal(err)
		}
	}

D
dongzhihong 已提交
124
	err = c.FinishInitParams()
125 126 127 128 129 130 131 132 133
	if err != nil {
		t.Fatal(err)
	}

	var grads []pserver.Gradient
	for i := 0; i < numParameter/2; i++ {
		var g pserver.Gradient
		g.Name = "p_" + strconv.Itoa(i)
		g.ElementType = pserver.Float32
D
dongzhihong 已提交
134
		g.Content = make([]byte, (i+1)*100)
135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
		grads = append(grads, g)
	}

	err = c.SendGrads(grads)
	if err != nil {
		t.Fatal(err)
	}

	names := make([]string, numParameter)
	for i := 0; i < numParameter; i++ {
		names[i] = "p_" + strconv.Itoa(i)
	}

	params, err := c.GetParams(names)
	if err != nil {
		t.Fatal(err)
	}

	if len(names) != len(params) {
		t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
	}

	for i := range params {
		if names[i] != params[i].Name {
159
			t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
160 161 162
		}
	}
}
Q
Qiao Longfei 已提交
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180

// TestNativeClient runs ClientTest against pservers started in-process. The
// client receives a fixed server list (via lister) rather than discovering
// servers through etcd.
func TestNativeClient(t *testing.T) {
	initNativeClient()

	servers := make([]client.Server, 0, len(pserverClientPorts))
	for idx, port := range pserverClientPorts {
		servers = append(servers, client.Server{Index: idx, Addr: ":" + strconv.Itoa(port)})
	}

	c1 := client.NewClient(lister(servers), len(servers), selector(true))
	ClientTest(t, c1)
}

// TODO: temporarily disabled because the test needs a running etcd server at
// etcdEndpoints. Rename to TestEtcdClient to re-enable it.
//
// EtcdClient registers in-process pservers in etcd (initEtcdClient). It then
// runs ClientTest with a client that discovers servers through
// client.NewEtcd.
func EtcdClient(t *testing.T) {
	initEtcdClient()
	etcdClient := client.NewEtcd(etcdEndpoints)
	c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
	ClientTest(t, c2)
}