crecordio.go 2.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
package main

/*
#include <string.h>

typedef int reader;
typedef int writer;
*/
import "C"

import (
	"log"
	"os"
	"strings"
	"unsafe"

	"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

// nullPtr is the zero-address unsafe.Pointer. This file uses it as the NULL
// value when checking pointers received from C and when returning "no data"
// to C.
var nullPtr = unsafe.Pointer(uintptr(0))

// writer pairs a recordio.Writer with the file it was created on, so that
// both can be closed together when the writer is released.
type writer struct {
	// w is the record writer created on top of f.
	w *recordio.Writer
	// f is the file backing w.
	f *os.File
}

type reader struct {
H
Helin Wang 已提交
28
	scanner *recordio.MultiScanner
29 30
}

H
Helin Wang 已提交
31 32 33
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
	if p == nullPtr {
		return nil
34 35
	}

H
Helin Wang 已提交
36 37 38 39 40 41
	// create a Go clice backed by a C array, reference:
	// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
	//
	// Go garbage collector will not interact with this data, need
	// to be freed properly.
	return (*[1 << 30]byte)(p)[:len:len]
42 43
}

H
Helin Wang 已提交
44 45
//export create_recordio_writer
func create_recordio_writer(path *C.char) C.writer {
46 47 48 49 50 51 52 53 54 55 56 57
	p := C.GoString(path)
	f, err := os.Create(p)
	if err != nil {
		log.Println(err)
		return -1
	}

	w := recordio.NewWriter(f, -1, -1)
	writer := &writer{f: f, w: w}
	return addWriter(writer)
}

H
Helin Wang 已提交
58 59
//export write_recordio
func write_recordio(writer C.writer, buf *C.uchar, size C.int) int {
60 61 62 63 64 65 66 67 68 69 70
	w := getWriter(writer)
	b := cArrayToSlice(unsafe.Pointer(buf), int(size))
	_, err := w.w.Write(b)
	if err != nil {
		log.Println(err)
		return -1
	}

	return 0
}

H
Helin Wang 已提交
71 72
//export release_recordio
func release_recordio(writer C.writer) {
73 74 75 76 77
	w := removeWriter(writer)
	w.w.Close()
	w.f.Close()
}

H
Helin Wang 已提交
78 79
//export create_recordio_reader
func create_recordio_reader(path *C.char) C.reader {
80
	p := C.GoString(path)
H
Helin Wang 已提交
81 82 83
	s, err := recordio.NewMultiScanner(strings.Split(p, ","))
	if err != nil {
		log.Println(err)
84 85 86
		return -1
	}

H
Helin Wang 已提交
87
	r := &reader{scanner: s}
88 89 90
	return addReader(r)
}

H
Helin Wang 已提交
91 92
//export read_next_item
func read_next_item(reader C.reader, size *C.int) *C.uchar {
93
	r := getReader(reader)
H
Helin Wang 已提交
94 95 96 97 98 99 100
	if r.scanner.Scan() {
		buf := r.scanner.Record()
		*size = C.int(len(buf))

		if len(buf) == 0 {
			return (*C.uchar)(nullPtr)
		}
101

H
Helin Wang 已提交
102 103 104
		ptr := C.malloc(C.size_t(len(buf)))
		C.memcpy(ptr, unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
		return (*C.uchar)(ptr)
105 106
	}

H
Helin Wang 已提交
107 108
	*size = -1
	return (*C.uchar)(nullPtr)
109 110
}

H
Helin Wang 已提交
111 112
//export release_recordio_reader
func release_recordio_reader(reader C.reader) {
113
	r := removeReader(reader)
H
Helin Wang 已提交
114
	r.scanner.Close()
115 116 117
}

// main is intentionally empty.
// NOTE(review): presumably this file is built as a C library
// (-buildmode=c-shared or c-archive), which requires package main to define
// main even though it is never called — confirm against the build scripts.
func main() {} // Required but ignored