// File: client_test.go

package pserver_test

import (
	"io/ioutil"
	"net"
	"net/http"
	"net/rpc"
	"strconv"
	"strings"
	"testing"

	"github.com/PaddlePaddle/Paddle/go/pserver"
)

// numPserver is the number of TCP listeners (and RPC servers served on them)
// that init starts for the tests in this file.
const numPserver = 10

// port[i] is the TCP port the i-th listener was bound to; it is filled in by
// init and read by the tests to build server addresses.
var port [numPserver]int

func init() {
	for i := 0; i < numPserver; i++ {
		l, err := net.Listen("tcp", ":0")
		if err != nil {
			panic(err)
		}

		ss := strings.Split(l.Addr().String(), ":")
		p, err := strconv.Atoi(ss[len(ss)-1])
		if err != nil {
			panic(err)
		}
		port[i] = p

		go func(l net.Listener) {
34
			s, err := pserver.NewService(0)
W
wuyi05 已提交
35 36 37
			if err != nil {
				panic(err)
			}
38
			server := rpc.NewServer()
W
wuyi05 已提交
39
			err = server.Register(s)
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
			if err != nil {
				panic(err)
			}

			mux := http.NewServeMux()
			mux.Handle(rpc.DefaultRPCPath, server)
			err = http.Serve(l, mux)
			if err != nil {
				panic(err)
			}
		}(l)
	}
}

// selector is a fixed-answer stub: the boolean it wraps is the answer that
// Select gives on every call.
type selector bool

// Select reports the boolean value the selector was created with.
func (s selector) Select() bool {
	if s {
		return true
	}
	return false
}

// lister is a static server list backed directly by a slice of
// pserver.Server values.
type lister []pserver.Server

// List returns the receiver's servers as a plain []pserver.Server. The
// result shares the receiver's backing array; no copy is made.
func (l lister) List() []pserver.Server {
	return []pserver.Server(l)
}

func TestClientFull(t *testing.T) {
	servers := make([]pserver.Server, numPserver)
	for i := 0; i < numPserver; i++ {
		servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
	}
	c := pserver.NewClient(lister(servers), len(servers), selector(true))
	selected := c.BeginInitParams()
	if !selected {
		t.Fatal("should be selected.")
	}

	const numParameter = 100
D
dongzhihong 已提交
78
	config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
D
dongzhihong 已提交
79 80 81
	if err != nil {
		t.Fatalf("read optimizer proto failed")
	}
82 83 84 85
	for i := 0; i < numParameter; i++ {
		var p pserver.Parameter
		p.Name = "p_" + strconv.Itoa(i)
		p.ElementType = pserver.Float32
D
dongzhihong 已提交
86
		p.Content = make([]byte, (i+1)*100)
D
dongzhihong 已提交
87
		err := c.InitParam(pserver.ParameterWithConfig{Param: p, Config: config})
88 89 90 91 92
		if err != nil {
			t.Fatal(err)
		}
	}

D
dongzhihong 已提交
93
	err = c.FinishInitParams()
94 95 96 97 98 99 100 101 102
	if err != nil {
		t.Fatal(err)
	}

	var grads []pserver.Gradient
	for i := 0; i < numParameter/2; i++ {
		var g pserver.Gradient
		g.Name = "p_" + strconv.Itoa(i)
		g.ElementType = pserver.Float32
D
dongzhihong 已提交
103
		g.Content = make([]byte, (i+1)*100)
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
		grads = append(grads, g)
	}

	err = c.SendGrads(grads)
	if err != nil {
		t.Fatal(err)
	}

	names := make([]string, numParameter)
	for i := 0; i < numParameter; i++ {
		names[i] = "p_" + strconv.Itoa(i)
	}

	params, err := c.GetParams(names)
	if err != nil {
		t.Fatal(err)
	}

	if len(names) != len(params) {
		t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
	}

	for i := range params {
		if names[i] != params[i].Name {
128
			t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
129 130 131
		}
	}
}