// recordio_test.go
package recordio_test

import (
	"bytes"
	"reflect"
	"testing"

8
	"github.com/PaddlePaddle/Paddle/go/recordio"
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
)

func TestWriteRead(t *testing.T) {
	const total = 1000
	var buf bytes.Buffer
	w := recordio.NewWriter(&buf, 0, -1)
	for i := 0; i < total; i++ {
		_, err := w.Write(make([]byte, i))
		if err != nil {
			t.Fatal(err)
		}
	}
	w.Close()

	idx, err := recordio.LoadIndex(bytes.NewReader(buf.Bytes()))
	if err != nil {
		t.Fatal(err)
	}

	if idx.NumRecords() != total {
		t.Fatal("num record does not match:", idx.NumRecords(), total)
	}

H
Helin Wang 已提交
32
	s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1)
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
	i := 0
	for s.Scan() {
		if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
			t.Fatal("not equal:", len(s.Record()), len(make([]byte, i)))
		}
		i++
	}

	if i != total {
		t.Fatal("total count not match:", i, total)
	}
}

func TestChunkIndex(t *testing.T) {
	const total = 1000
	var buf bytes.Buffer
	w := recordio.NewWriter(&buf, 0, -1)
	for i := 0; i < total; i++ {
		_, err := w.Write(make([]byte, i))
		if err != nil {
			t.Fatal(err)
		}
	}
	w.Close()

	idx, err := recordio.LoadIndex(bytes.NewReader(buf.Bytes()))
	if err != nil {
		t.Fatal(err)
	}

	if idx.NumChunks() != total {
		t.Fatal("unexpected chunk num:", idx.NumChunks(), total)
	}

	for i := 0; i < total; i++ {
		newIdx := idx.ChunkIndex(i)
H
Helin Wang 已提交
69
		s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1)
70 71 72 73 74 75 76 77 78 79 80 81
		j := 0
		for s.Scan() {
			if !reflect.DeepEqual(s.Record(), make([]byte, i)) {
				t.Fatal("not equal:", len(s.Record()), len(make([]byte, i)))
			}
			j++
		}
		if j != 1 {
			t.Fatal("unexpected record per chunk:", j)
		}
	}
}