// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package master_test

import (
	"fmt"
	"net"
	"net/http"
	"net/rpc"
	"os"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/PaddlePaddle/recordio"
)

// goid returns the numeric ID of the calling goroutine, parsed from the
// header line that runtime.Stack writes ("goroutine <id> [...]"). It is a
// test helper for printing goroutine IDs and panics if the ID cannot be
// parsed.
func goid() int {
	stack := make([]byte, 64)
	stack = stack[:runtime.Stack(stack, false)]
	rest := strings.TrimPrefix(string(stack), "goroutine ")
	id, err := strconv.Atoi(strings.Fields(rest)[0])
	if err == nil {
		return id
	}
	panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}

func TestNextRecord(t *testing.T) {
	const (
		path  = "/tmp/master_client_TestFull"
		total = 50
	)
51 52 53 54 55 56 57 58 59 60 61
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}

	ss := strings.Split(l.Addr().String(), ":")
	p, err := strconv.Atoi(ss[len(ss)-1])
	if err != nil {
		panic(err)
	}
	go func(l net.Listener) {
62
		s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
63 64 65 66
		if err != nil {
			panic(err)
		}

67
		server := rpc.NewServer()
68
		err = server.Register(s)
69 70 71 72 73 74 75 76 77 78 79 80
		if err != nil {
			panic(err)
		}

		mux := http.NewServeMux()
		mux.Handle(rpc.DefaultRPCPath, server)
		err = http.Serve(l, mux)
		if err != nil {
			panic(err)
		}
	}(l)

H
Helin Wang 已提交
81
	f, err := os.Create(path)
82 83 84 85
	if err != nil {
		panic(err)
	}

86
	w := recordio.NewWriter(f, 1, -1)
H
Helin Wang 已提交
87
	for i := 0; i < total; i++ {
H
Helin Wang 已提交
88 89 90 91 92 93 94 95 96 97 98 99 100 101
		_, err = w.Write([]byte{byte(i)})
		if err != nil {
			panic(err)
		}
	}

	err = w.Close()
	if err != nil {
		panic(err)
	}

	err = f.Close()
	if err != nil {
		panic(err)
102
	}
H
Helin Wang 已提交
103

104 105 106 107 108 109 110 111 112 113 114
	// start several client to test task fetching
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		// test for multiple concurrent clients
		go func() {
			defer wg.Done()
			// each go-routine needs a single client connection instance
			c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
			if e != nil {
				t.Fatal(e)
G
gongweibao 已提交
115
			}
116 117 118
			e = c.SetDataset([]string{path})
			if e != nil {
				panic(e)
119
			}
G
gongweibao 已提交
120

121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
			// test for n passes
			for pass := 0; pass < 10; pass++ {
				c.StartGetRecords(pass)

				received := make(map[byte]bool)
				taskid := 0
				for {
					r, e := c.NextRecord()
					if e != nil {
						// ErrorPassAfter will wait, else break for next pass
						if e.Error() == master.ErrPassBefore.Error() ||
							e.Error() == master.ErrNoMoreAvailable.Error() {
							break
						}
						t.Fatal(pass, taskid, "Read error:", e)
					}
					if len(r) != 1 {
						t.Fatal(pass, taskid, "Length should be 1.", r)
					}
					if received[r[0]] {
						t.Fatal(pass, taskid, "Received duplicate.", received, r)
					}
					taskid++
					received[r[0]] = true
				}
146
			}
147
		}()
148
	}
149
	wg.Wait()
150
}