package client_test

import (
	"io/ioutil"
	"net/url"
	"os"
	"strings"
	"sync"
	"testing"

	"github.com/PaddlePaddle/Paddle/go/pserver/client"
	"github.com/coreos/etcd/embed"
)

func TestSelector(t *testing.T) {
	etcdDir, err := ioutil.TempDir("", "")
	if err != nil {
		t.Fatal(err)
	}
	cfg := embed.NewConfig()
21 22 23 24
	lpurl, _ := url.Parse("http://localhost:0")
	lcurl, _ := url.Parse("http://localhost:0")
	cfg.LPUrls = []url.URL{*lpurl}
	cfg.LCUrls = []url.URL{*lcurl}
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
	cfg.Dir = etcdDir
	e, err := embed.StartEtcd(cfg)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		e.Close()
		if err := os.RemoveAll(etcdDir); err != nil {
			t.Fatal(err)
		}
	}()

	<-e.Server.ReadyNotify()

40 41 42
	port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
	endpoint := "127.0.0.1:" + port

43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
	var mu sync.Mutex
	selectedCount := 0
	var wg sync.WaitGroup
	selectAndDone := func(c *client.Etcd) {
		defer wg.Done()

		selected, err := c.Select()
		if err != nil {
			panic(err)
		}

		if selected {
			mu.Lock()
			selectedCount++
			mu.Unlock()
			err = c.Done()
			if err != nil {
				t.Fatal(err)
			}
		}
	}

65 66 67 68
	c0 := client.NewEtcd(endpoint)
	c1 := client.NewEtcd(endpoint)
	c2 := client.NewEtcd(endpoint)
	c3 := client.NewEtcd(endpoint)
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
	wg.Add(3)
	go selectAndDone(c0)
	go selectAndDone(c1)
	go selectAndDone(c2)
	wg.Wait()

	// simulate trainer crashed and restarted after the
	// initialization process.
	wg.Add(1)
	go selectAndDone(c3)
	wg.Wait()

	mu.Lock()
	if selectedCount != 1 {
		t.Fatal("selected count wrong:", selectedCount)
	}
	mu.Unlock()

	err = c0.Close()
	if err != nil {
		t.Fatal(err)
	}

	err = c1.Close()
	if err != nil {
		t.Fatal(err)
	}

	err = c2.Close()
	if err != nil {
		t.Fatal(err)
	}

	err = c3.Close()
	if err != nil {
		t.Fatal(err)
	}
}