crecordio.go 2.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
package main

/*
#include <string.h>

typedef int reader;
typedef int writer;
*/
import "C"

import (
	"log"
	"os"
	"strings"
	"unsafe"

17
	"github.com/PaddlePaddle/Paddle/go/recordio"
18 19 20 21 22 23 24 25 26 27
)

// nullPtr is the nil unsafe.Pointer, used as the NULL sentinel when
// exchanging pointers with C. The zero value of unsafe.Pointer is already
// nil, so no integer-to-pointer conversion (which the unsafe package rules
// discourage) is needed.
var nullPtr unsafe.Pointer

// writer holds the Go-side state behind a C.writer handle.
type writer struct {
	// w encodes records onto f (see create_recordio_writer).
	w *recordio.Writer
	// f is the output file; it is closed together with w when the handle
	// is released.
	f *os.File
}

type reader struct {
H
Helin Wang 已提交
28
	scanner *recordio.Scanner
29 30
}

H
Helin Wang 已提交
31 32 33
// cArrayToSlice returns a Go byte slice that aliases len bytes of memory
// starting at p, without copying. Reference:
// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
//
// The Go garbage collector does not manage the backing memory: the caller
// keeps ownership, must keep it alive for as long as the slice is used, and
// must free it on the C side afterwards.
//
// A nil p or a non-positive len yields a nil slice. Without the len check a
// negative length (e.g. a bad size passed in from C) would make the slice
// expression below panic, and an unrecovered panic terminates the whole
// program, C host included. len must not exceed 1<<30, the length of the
// array type used for the conversion.
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
	if p == nil || len <= 0 {
		return nil
	}
	return (*[1 << 30]byte)(p)[:len:len]
}

H
Helin Wang 已提交
44 45
//export create_recordio_writer
func create_recordio_writer(path *C.char) C.writer {
46 47 48 49 50 51 52 53 54 55 56 57
	p := C.GoString(path)
	f, err := os.Create(p)
	if err != nil {
		log.Println(err)
		return -1
	}

	w := recordio.NewWriter(f, -1, -1)
	writer := &writer{f: f, w: w}
	return addWriter(writer)
}

H
Helin Wang 已提交
58 59
//export recordio_write
func recordio_write(writer C.writer, buf *C.uchar, size C.int) C.int {
60 61
	w := getWriter(writer)
	b := cArrayToSlice(unsafe.Pointer(buf), int(size))
H
Helin Wang 已提交
62
	c, err := w.w.Write(b)
63 64 65 66 67
	if err != nil {
		log.Println(err)
		return -1
	}

H
Helin Wang 已提交
68
	return C.int(c)
69 70
}

H
Helin Wang 已提交
71 72
//export release_recordio_writer
func release_recordio_writer(writer C.writer) {
73 74 75 76 77
	w := removeWriter(writer)
	w.w.Close()
	w.f.Close()
}

H
Helin Wang 已提交
78 79
//export create_recordio_reader
func create_recordio_reader(path *C.char) C.reader {
80
	p := C.GoString(path)
H
Helin Wang 已提交
81
	s, err := recordio.NewScanner(strings.Split(p, ",")...)
H
Helin Wang 已提交
82 83
	if err != nil {
		log.Println(err)
84 85 86
		return -1
	}

H
Helin Wang 已提交
87
	r := &reader{scanner: s}
88 89 90
	return addReader(r)
}

H
Helin Wang 已提交
91 92
//export recordio_read
func recordio_read(reader C.reader, record **C.uchar) C.int {
93
	r := getReader(reader)
H
Helin Wang 已提交
94 95 96
	if r.scanner.Scan() {
		buf := r.scanner.Record()
		if len(buf) == 0 {
H
Helin Wang 已提交
97 98
			*record = (*C.uchar)(nullPtr)
			return 0
H
Helin Wang 已提交
99
		}
100

H
Helin Wang 已提交
101 102 103 104
		size := C.int(len(buf))
		*record = (*C.uchar)(C.malloc(C.size_t(len(buf))))
		C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
		return size
105 106
	}

H
Helin Wang 已提交
107
	return -1
108 109
}

H
Helin Wang 已提交
110 111
//export release_recordio_reader
func release_recordio_reader(reader C.reader) {
112
	r := removeReader(reader)
H
Helin Wang 已提交
113
	r.scanner.Close()
114 115 116
}

// main is required by package main but is not used: this package exists to
// export functions to C (presumably built with -buildmode=c-archive or
// c-shared, where main is never run).
func main() {
}