// crecordio.go exposes the recordio package to C through cgo-exported functions.
package main

/*
#include <string.h>

typedef int reader;
typedef int writer;
*/
import "C"

import (
	"io"
	"log"
	"os"
	"path/filepath"
	"strings"
	"unsafe"

	"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

// nullPtr is a nil unsafe.Pointer (the zero value of the type).
var nullPtr unsafe.Pointer

// writer is the Go-side state registered behind a C.writer handle: a
// recordio writer plus the file it was created on.
type writer struct {
	// w encodes records into f.
	w *recordio.Writer
	// f is the underlying output file; it must be closed after w.
	f *os.File
}

// reader is the Go-side state registered behind a C.reader handle.
type reader struct {
	// buffer receives records from the background read goroutine, which
	// closes it when it stops.
	buffer chan []byte
	// cancel is closed to ask the background read goroutine to stop early.
	cancel chan struct{}
}

// read streams every record of the recordio files in paths, file by file in
// the given order, into buffer. A file that cannot be opened, or whose index
// cannot be loaded, is logged and skipped. read stops early once cancel is
// closed. On every exit path it closes buffer, so the receiver observes a
// closed channel after the last record (or after cancellation).
func read(paths []string, buffer chan<- []byte, cancel chan struct{}) {
	defer close(buffer)

	for _, path := range paths {
		// Do not open another file if the reader has already been released.
		select {
		case <-cancel:
			return
		default:
		}

		if !readRecordFile(path, buffer, cancel) {
			return
		}
	}
}

// readRecordFile sends each record of the recordio file at path to buffer.
// It returns false if cancel was closed while waiting to send, and true
// otherwise, including when the file was skipped because it could not be
// opened or indexed. The file is always closed before returning.
func readRecordFile(path string, buffer chan<- []byte, cancel <-chan struct{}) bool {
	f, err := os.Open(path)
	if err != nil {
		// Skip unreadable files, but leave a trace of why they were skipped.
		log.Println(err)
		return true
	}
	defer func() {
		if err := f.Close(); err != nil {
			log.Println(err)
		}
	}()

	idx, err := recordio.LoadIndex(f)
	if err != nil {
		log.Println(err)
		return true
	}

	// NOTE(review): 0, -1 presumably mean "from the first record to the
	// end" — confirm against recordio.NewScanner.
	s := recordio.NewScanner(f, idx, 0, -1)
	for s.Scan() {
		select {
		case buffer <- s.Record():
		case <-cancel:
			return false
		}
	}

	// io.EOF is not treated as a failure of the scan.
	if err := s.Error(); err != nil && err != io.EOF {
		log.Println(err)
	}
	return true
}

// paddle_new_writer creates (or truncates) the file named by path and
// registers a recordio writer on it, returning its handle. It returns -1
// if the file cannot be created.
//
//export paddle_new_writer
func paddle_new_writer(path *C.char) C.writer {
	f, err := os.Create(C.GoString(path))
	if err != nil {
		log.Println(err)
		return -1
	}

	// NOTE(review): the two -1 arguments are passed to recordio.NewWriter
	// as-is; presumably they select its defaults — confirm.
	return addWriter(&writer{
		w: recordio.NewWriter(f, -1, -1),
		f: f,
	})
}

// cArrayToSlice returns a slice of length and capacity n that aliases the n
// bytes of memory starting at p, without copying. It returns nil when p is
// nil.
//
// The Go garbage collector does not manage this memory: it must stay valid
// for as long as the slice is in use and must be freed from the C side.
// n must be in [0, 1<<30]; otherwise the slice expression panics.
//
// Reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
func cArrayToSlice(p unsafe.Pointer, n int) []byte {
	if p == nil {
		return nil
	}

	// View p as a pointer to a very large array, then slice it down. The
	// three-index form caps the capacity at n so an append can never write
	// past the end of the C buffer.
	return (*[1 << 30]byte)(p)[:n:n]
}

// paddle_writer_write appends one record, the size bytes starting at buf,
// to the writer identified by writer. A NULL buf is handed to the recordio
// writer as a nil slice. It returns 0 on success, or -1 if size is negative
// or the write fails.
//
//export paddle_writer_write
func paddle_writer_write(writer C.writer, buf *C.uchar, size C.int) int {
	// A negative size would make the slice conversion panic, and an
	// unrecovered panic in an exported function aborts the whole C host
	// process; report it as an ordinary error instead.
	if size < 0 {
		log.Printf("paddle_writer_write: invalid record size %d", size)
		return -1
	}

	w := getWriter(writer)
	b := cArrayToSlice(unsafe.Pointer(buf), int(size))
	if _, err := w.w.Write(b); err != nil {
		log.Println(err)
		return -1
	}
	return 0
}

// paddle_writer_release unregisters the writer identified by writer, closes
// its recordio writer and then its underlying file. Errors from either close
// are logged; the file is closed even if closing the recordio writer fails.
//
//export paddle_writer_release
func paddle_writer_release(writer C.writer) {
	w := removeWriter(writer)

	// NOTE(review): closing the recordio writer presumably flushes buffered
	// records, so a failure here may mean data loss — do not drop it.
	if err := w.w.Close(); err != nil {
		log.Println(err)
	}
	if err := w.f.Close(); err != nil {
		log.Println(err)
	}
}

// paddle_new_reader creates a reader over the recordio files named by path,
// a comma-separated list of file paths or filepath.Glob patterns. Records
// are read by a background goroutine into a channel that buffers up to
// bufferSize records. It returns the reader handle, or -1 if bufferSize is
// negative, a pattern is malformed, or no existing file matches.
//
//export paddle_new_reader
func paddle_new_reader(path *C.char, bufferSize C.int) C.reader {
	// make panics on a negative channel capacity, and an unrecovered panic
	// in an exported function aborts the whole C host process.
	if bufferSize < 0 {
		log.Printf("paddle_new_reader: invalid buffer size %d", bufferSize)
		return -1
	}

	p := C.GoString(path)
	var paths []string
	for _, s := range strings.Split(p, ",") {
		match, err := filepath.Glob(s)
		if err != nil {
			log.Printf("error applying glob to %s: %v\n", s, err)
			return -1
		}

		paths = append(paths, match...)
	}

	if len(paths) == 0 {
		log.Println("no valid path provided.", p)
		return -1
	}

	buffer := make(chan []byte, int(bufferSize))
	cancel := make(chan struct{})
	r := &reader{buffer: buffer, cancel: cancel}
	go read(paths, buffer, cancel)
	return addReader(r)
}

// paddle_reader_next_item blocks until the next record of reader is
// available and returns a copy of it, storing its length in *size.
//
// When the reader has no more records, *size is set to -1 and NULL is
// returned. An empty record yields *size == 0 and NULL. Otherwise the
// returned buffer comes from C.malloc, so the caller owns it and must
// free() it.
//
//export paddle_reader_next_item
func paddle_reader_next_item(reader C.reader, size *C.int) *C.uchar {
	record, open := <-getReader(reader).buffer
	switch {
	case !open:
		// The channel is closed and drained: end of data.
		*size = -1
		return nil
	case len(record) == 0:
		// Empty record: nothing to copy.
		*size = 0
		return nil
	}

	n := C.size_t(len(record))
	out := C.malloc(n)
	C.memcpy(out, unsafe.Pointer(&record[0]), n)
	*size = C.int(len(record))
	return (*C.uchar)(out)
}

// paddle_reader_release unregisters the reader identified by reader and
// closes its cancel channel, signalling its background read goroutine to
// stop.
//
//export paddle_reader_release
func paddle_reader_release(reader C.reader) {
	close(removeReader(reader).cancel)
}

// main is required for a main package but is never meant to do anything:
// presumably this package is built as a C library (e.g. with
// -buildmode=c-shared or c-archive), so the exported functions are the
// real entry points.
func main() {
	// Intentionally empty.
}