diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index cf31b4a3429cc5d92fcde1118937c22cb0f34aee..9898dc083ebb1783a0e2ddd12afaa9c3d5a79e98 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -9,9 +9,10 @@ add_subdirectory(pserver) add_subdirectory(trainer) add_subdirectory(scripts) -if(CMAKE_Go_COMPILER) - add_subdirectory(go) -endif() +# Do not build go directory until go cmake is working smoothly. +# if(CMAKE_Go_COMPILER) +# add_subdirectory(go) +# endif() find_package(Boost QUIET) diff --git a/paddle/go/CMakeLists.txt b/paddle/go/CMakeLists.txt deleted file mode 100644 index 51c5252d66374fbc55abc0e8ede8fccd0f4dead7..0000000000000000000000000000000000000000 --- a/paddle/go/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -include_directories(${CMAKE_CURRENT_BINARY_DIR}) - -go_library(adder SRCS adder.go) - -if (WITH_TESTING) - cc_test(cgo_test - SRCS - cgo_test.cc - DEPS - adder) -endif() diff --git a/paddle/go/adder.go b/paddle/go/adder.go deleted file mode 100644 index e14f40fd9feb23aa55b71f3c422445b7fbfd827f..0000000000000000000000000000000000000000 --- a/paddle/go/adder.go +++ /dev/null @@ -1,10 +0,0 @@ -package main - -import "C" - -//export GoAdder -func GoAdder(x, y int) int { - return x + y -} - -func main() {} // Required but ignored diff --git a/paddle/go/cclient/CMakeLists.txt b/paddle/go/cclient/CMakeLists.txt index c85ff3db09d442a3e51f061993b5f02f3e69e2bb..dfd104fb589203e31ff183f134735fd302b263ab 100644 --- a/paddle/go/cclient/CMakeLists.txt +++ b/paddle/go/cclient/CMakeLists.txt @@ -1,31 +1,12 @@ cmake_minimum_required(VERSION 3.0) -if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES) - message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})") -else() - # find cmake directory modules - get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) - get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) - get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") +project(cxx_go C Go) - # enable c++11 - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +include(golang) +include(flags) - # enable gtest - set(THIRD_PARTY_PATH ./third_party) - set(WITH_TESTING ON) - include(external/gtest) -endif() - -set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") - -project(cxx_go CXX C Go) - -include(cmake/golang.cmake) -include(cmake/flags.cmake) - -ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver) -add_go_library(client STATIC pserver) +go_library(client STATIC) add_subdirectory(test) diff --git a/paddle/go/cclient/cclient.go b/paddle/go/cclient/cclient.go index dc86d47e8d0a97e3d78b174f84add8b9a3730f1f..ee2d9d24fd82fffd8b77210c6c167c3364cc6da2 100644 --- a/paddle/go/cclient/cclient.go +++ b/paddle/go/cclient/cclient.go @@ -78,8 +78,11 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { return nil } - // create a Go clice backed by a C array, - // reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // create a Go clice backed by a C array, reference: + // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // + // Go garbage collector will not interact with this data, need + // to be freed properly. return (*[1 << 30]byte)(p)[:len:len] } diff --git a/paddle/go/cclient/cmake/golang.cmake b/paddle/go/cclient/cmake/golang.cmake deleted file mode 100644 index 5d39868bfdfbfbeb88861c7829b6485589993052..0000000000000000000000000000000000000000 --- a/paddle/go/cclient/cmake/golang.cmake +++ /dev/null @@ -1,46 +0,0 @@ -set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go") -file(MAKE_DIRECTORY ${GOPATH}) - -function(ExternalGoProject_Add TARG) - add_custom_target(${TARG} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN}) -endfunction(ExternalGoProject_Add) - -function(add_go_executable NAME) - file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") - add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp - COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build - -o "${CMAKE_CURRENT_BINARY_DIR}/${NAME}" - ${CMAKE_GO_FLAGS} ${GO_SOURCE} - WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) - - add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) - install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${NAME} DESTINATION bin) -endfunction(add_go_executable) - - -function(ADD_GO_LIBRARY NAME BUILD_TYPE) - if(BUILD_TYPE STREQUAL "STATIC") - set(BUILD_MODE -buildmode=c-archive) - set(LIB_NAME "lib${NAME}.a") - else() - set(BUILD_MODE -buildmode=c-shared) - if(APPLE) - set(LIB_NAME "lib${NAME}.dylib") - else() - set(LIB_NAME "lib${NAME}.so") - endif() - endif() - - file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") - add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp - COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} - -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" - ${CMAKE_GO_FLAGS} ${GO_SOURCE} - WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) - - add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) - - if(NOT BUILD_TYPE STREQUAL "STATIC") - install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin) - endif() -endfunction(ADD_GO_LIBRARY) diff --git a/paddle/go/cclient/test/main.c b/paddle/go/cclient/test/main.c index 28e3d03b7a000d3251a8d525ce50ca664eff3424..abfb32e5603f5b51036f1f8475c8e3aca2f05ccb 100644 --- a/paddle/go/cclient/test/main.c +++ b/paddle/go/cclient/test/main.c @@ -1,11 +1,12 @@ -#include "libclient.h" +#include -//#include "gtest/gtest.h" +#include "libclient.h" -void panic() { +void fail() { // TODO(helin): fix: gtest using cmake is not working, using this // hacky way for now. - *(void*)0; + printf("test failed.\n"); + exit(-1); } int main() { @@ -35,7 +36,7 @@ retry: goto retry; } } else { - panic(); + fail(); } char content[] = {0x00, 0x11, 0x22}; @@ -44,25 +45,25 @@ retry: {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; if (!paddle_send_grads(c, grads, 2)) { - panic(); + fail(); } paddle_parameter* params[2] = {NULL, NULL}; char* names[] = {"param_a", "param_b"}; if (!paddle_get_params(c, names, params, 2)) { - panic(); + fail(); } // get parameters again by reusing the allocated parameter buffers. if (!paddle_get_params(c, names, params, 2)) { - panic(); + fail(); } paddle_release_param(params[0]); paddle_release_param(params[1]); if (!paddle_save_model(c, "/tmp/")) { - panic(); + fail(); } return 0; diff --git a/paddle/go/cgo_test.cc b/paddle/go/cgo_test.cc deleted file mode 100644 index 64efa606fff260485c375b961d5e485296edfe2c..0000000000000000000000000000000000000000 --- a/paddle/go/cgo_test.cc +++ /dev/null @@ -1,5 +0,0 @@ -#include -#include "gtest/gtest.h" -#include "libadder.h" - -TEST(Cgo, Invoke) { EXPECT_EQ(GoAdder(30, 12), 42); } diff --git a/paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake b/paddle/go/cmake/CMakeDetermineGoCompiler.cmake similarity index 94% rename from paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake rename to paddle/go/cmake/CMakeDetermineGoCompiler.cmake index b3f8fbe271d80aaa72d90d167a0d8130bec7f362..a9bb6906c7440782bd648bb7505a548248a11bb0 100644 --- a/paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake +++ b/paddle/go/cmake/CMakeDetermineGoCompiler.cmake @@ -38,7 +38,7 @@ endif() mark_as_advanced(CMAKE_Go_COMPILER) -configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in +configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in ${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY) set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") diff --git a/paddle/go/cclient/cmake/CMakeGoCompiler.cmake.in b/paddle/go/cmake/CMakeGoCompiler.cmake.in similarity index 100% rename from paddle/go/cclient/cmake/CMakeGoCompiler.cmake.in rename to paddle/go/cmake/CMakeGoCompiler.cmake.in diff --git a/paddle/go/cclient/cmake/CMakeGoInformation.cmake b/paddle/go/cmake/CMakeGoInformation.cmake similarity index 100% rename from paddle/go/cclient/cmake/CMakeGoInformation.cmake rename to paddle/go/cmake/CMakeGoInformation.cmake diff --git a/paddle/go/cclient/cmake/CMakeTestGoCompiler.cmake b/paddle/go/cmake/CMakeTestGoCompiler.cmake similarity index 100% rename from paddle/go/cclient/cmake/CMakeTestGoCompiler.cmake rename to paddle/go/cmake/CMakeTestGoCompiler.cmake diff --git a/paddle/go/cclient/cmake/flags.cmake b/paddle/go/cmake/flags.cmake similarity index 95% rename from paddle/go/cclient/cmake/flags.cmake rename to paddle/go/cmake/flags.cmake index 062d5ab660dad2327d9f514f22c2868cc0f161a7..a167c432a920e9ee93878603f3b946e8593412f6 100644 --- a/paddle/go/cclient/cmake/flags.cmake +++ b/paddle/go/cmake/flags.cmake @@ -21,7 +21,7 @@ function(CheckCompilerCXX11Flag) if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") endif() - endif() + endif() endif() endfunction() @@ -42,4 +42,4 @@ if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0") list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60") endif() -set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) \ No newline at end of file +set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) diff --git a/paddle/go/cmake/golang.cmake b/paddle/go/cmake/golang.cmake new file mode 100644 index 0000000000000000000000000000000000000000..e73b0c865bcf066302646713fa9311b3e3489235 --- /dev/null +++ b/paddle/go/cmake/golang.cmake @@ -0,0 +1,50 @@ +set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go") +file(MAKE_DIRECTORY ${GOPATH}) +set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle") +file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH}) + +function(GO_LIBRARY NAME BUILD_TYPE) + if(BUILD_TYPE STREQUAL "STATIC") + set(BUILD_MODE -buildmode=c-archive) + set(LIB_NAME "lib${NAME}.a") + else() + set(BUILD_MODE -buildmode=c-shared) + if(APPLE) + set(LIB_NAME "lib${NAME}.dylib") + else() + set(LIB_NAME "lib${NAME}.so") + endif() + endif() + + file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") + file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) + + # find Paddle directory. + get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) + get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) + get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY) + + # automatically get all dependencies specified in the source code + # for given target. + add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...) + + # make a symlink that references Paddle inside $GOPATH, so go get + # will use the local changes in Paddle rather than checkout Paddle + # in github. + add_custom_target(copyPaddle + COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) + add_dependencies(goGet copyPaddle) + + add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp + COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} + -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" + ${CMAKE_GO_FLAGS} ${GO_SOURCE} + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) + + add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) + add_dependencies(${NAME} goGet) + + if(NOT BUILD_TYPE STREQUAL "STATIC") + install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin) + endif() +endfunction(GO_LIBRARY) diff --git a/paddle/go/crecordio/CMakeLists.txt b/paddle/go/crecordio/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c395fe0b4a483c5a7db282d9f02c19bb143c5aeb --- /dev/null +++ b/paddle/go/crecordio/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.0) + +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") + +project(cxx_go C Go) + +include(golang) +include(flags) + +go_library(recordio STATIC) +add_subdirectory(test) diff --git a/paddle/go/crecordio/crecordio.go b/paddle/go/crecordio/crecordio.go new file mode 100644 index 0000000000000000000000000000000000000000..33f97de8cf73be01313e0a8b27815489dbaceb64 --- /dev/null +++ b/paddle/go/crecordio/crecordio.go @@ -0,0 +1,116 @@ +package main + +/* +#include + +typedef int reader; +typedef int writer; +*/ +import "C" + +import ( + "log" + "os" + "strings" + "unsafe" + + "github.com/PaddlePaddle/Paddle/paddle/go/recordio" +) + +var nullPtr = unsafe.Pointer(uintptr(0)) + +type writer struct { + w *recordio.Writer + f *os.File +} + +type reader struct { + scanner *recordio.Scanner +} + +func cArrayToSlice(p unsafe.Pointer, len int) []byte { + if p == nullPtr { + return nil + } + + // create a Go clice backed by a C array, reference: + // https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + // + // Go garbage collector will not interact with this data, need + // to be freed properly. + return (*[1 << 30]byte)(p)[:len:len] +} + +//export create_recordio_writer +func create_recordio_writer(path *C.char) C.writer { + p := C.GoString(path) + f, err := os.Create(p) + if err != nil { + log.Println(err) + return -1 + } + + w := recordio.NewWriter(f, -1, -1) + writer := &writer{f: f, w: w} + return addWriter(writer) +} + +//export recordio_write +func recordio_write(writer C.writer, buf *C.uchar, size C.int) C.int { + w := getWriter(writer) + b := cArrayToSlice(unsafe.Pointer(buf), int(size)) + c, err := w.w.Write(b) + if err != nil { + log.Println(err) + return -1 + } + + return C.int(c) +} + +//export release_recordio_writer +func release_recordio_writer(writer C.writer) { + w := removeWriter(writer) + w.w.Close() + w.f.Close() +} + +//export create_recordio_reader +func create_recordio_reader(path *C.char) C.reader { + p := C.GoString(path) + s, err := recordio.NewScanner(strings.Split(p, ",")...) + if err != nil { + log.Println(err) + return -1 + } + + r := &reader{scanner: s} + return addReader(r) +} + +//export recordio_read +func recordio_read(reader C.reader, record **C.uchar) C.int { + r := getReader(reader) + if r.scanner.Scan() { + buf := r.scanner.Record() + if len(buf) == 0 { + *record = (*C.uchar)(nullPtr) + return 0 + } + + size := C.int(len(buf)) + *record = (*C.uchar)(C.malloc(C.size_t(len(buf)))) + C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + return size + } + + return -1 +} + +//export release_recordio_reader +func release_recordio_reader(reader C.reader) { + r := removeReader(reader) + r.scanner.Close() +} + +func main() {} // Required but ignored diff --git a/paddle/go/crecordio/register.go b/paddle/go/crecordio/register.go new file mode 100644 index 0000000000000000000000000000000000000000..61dfdbd4ab64a05a25cc24219456853a010c4ce4 --- /dev/null +++ b/paddle/go/crecordio/register.go @@ -0,0 +1,61 @@ +package main + +/* +typedef int reader; +typedef int writer; +*/ +import "C" + +import "sync" + +var mu sync.Mutex +var handleMap = make(map[C.reader]*reader) +var curHandle C.reader +var writerMap = make(map[C.writer]*writer) +var curWriterHandle C.writer + +func addReader(r *reader) C.reader { + mu.Lock() + defer mu.Unlock() + reader := curHandle + curHandle++ + handleMap[reader] = r + return reader +} + +func getReader(reader C.reader) *reader { + mu.Lock() + defer mu.Unlock() + return handleMap[reader] +} + +func removeReader(reader C.reader) *reader { + mu.Lock() + defer mu.Unlock() + r := handleMap[reader] + delete(handleMap, reader) + return r +} + +func addWriter(w *writer) C.writer { + mu.Lock() + defer mu.Unlock() + writer := curWriterHandle + curWriterHandle++ + writerMap[writer] = w + return writer +} + +func getWriter(writer C.writer) *writer { + mu.Lock() + defer mu.Unlock() + return writerMap[writer] +} + +func removeWriter(writer C.writer) *writer { + mu.Lock() + defer mu.Unlock() + w := writerMap[writer] + delete(writerMap, writer) + return w +} diff --git a/paddle/go/crecordio/test/CMakeLists.txt b/paddle/go/crecordio/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bac1006ae12e07574afaa4b00160b559d173c332 --- /dev/null +++ b/paddle/go/crecordio/test/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.0) + +include_directories(${CMAKE_BINARY_DIR}) + +add_executable(recordio_test test.c) +add_dependencies(recordio_test recordio) +set (CMAKE_EXE_LINKER_FLAGS "-pthread") +target_link_libraries(recordio_test ${CMAKE_BINARY_DIR}/librecordio.a) diff --git a/paddle/go/crecordio/test/test.c b/paddle/go/crecordio/test/test.c new file mode 100644 index 0000000000000000000000000000000000000000..b25536a9d76a8654cf1b15075c76887495e1d9bd --- /dev/null +++ b/paddle/go/crecordio/test/test.c @@ -0,0 +1,56 @@ +#include +#include + +#include "librecordio.h" + +void fail() { + // TODO(helin): fix: gtest using cmake is not working, using this + // hacky way for now. + printf("test failed.\n"); + exit(-1); +} + +int main() { + writer w = create_recordio_writer("/tmp/test_recordio_0"); + recordio_write(w, "hello", 6); + recordio_write(w, "hi", 3); + release_recordio_writer(w); + + w = create_recordio_writer("/tmp/test_recordio_1"); + recordio_write(w, "dog", 4); + recordio_write(w, "cat", 4); + release_recordio_writer(w); + + reader r = create_recordio_reader("/tmp/test_recordio_*"); + unsigned char* item = NULL; + int size = recordio_read(r, &item); + if (strcmp(item, "hello") || size != 6) { + fail(); + } + free(item); + + size = recordio_read(r, &item); + if (strcmp(item, "hi") || size != 3) { + fail(); + } + free(item); + + size = recordio_read(r, &item); + if (strcmp(item, "dog") || size != 4) { + fail(); + } + free(item); + + size = recordio_read(r, &item); + if (strcmp(item, "cat") || size != 4) { + fail(); + } + free(item); + + size = recordio_read(r, &item); + if (size != -1) { + fail(); + } + + release_recordio_reader(r); +} diff --git a/paddle/go/recordio/README.md b/paddle/go/recordio/README.md index 8b0b9308b1ade3560d6bda150ea0139a9fb2503b..50e7e954764ec6f26397c6a24296b1bf65403d69 100644 --- a/paddle/go/recordio/README.md +++ b/paddle/go/recordio/README.md @@ -8,6 +8,7 @@ w := recordio.NewWriter(f) w.Write([]byte("Hello")) w.Write([]byte("World!")) w.Close() +f.Close() ``` ## Read @@ -18,6 +19,7 @@ w.Close() f, e := os.Open("a_file.recordio") idx, e := recordio.LoadIndex(f) fmt.Println("Total records: ", idx.Len()) + f.Close() ``` 2. Create one or more scanner to read a range of records. The @@ -30,7 +32,8 @@ w.Close() for s.Scan() { fmt.Println(string(s.Record())) } - if s.Err() != nil && s.Err() != io.EOF { + if s.Err() != nil { log.Fatalf("Something wrong with scanning: %v", e) } + f.Close() ``` diff --git a/paddle/go/recordio/reader.go b/paddle/go/recordio/range_scanner.go similarity index 84% rename from paddle/go/recordio/reader.go rename to paddle/go/recordio/range_scanner.go index a12c604f7b2f5c103624aac538034ec6a883c536..46e2eee68c7b7fc6bb1b69f60a75fd85cfe85576 100644 --- a/paddle/go/recordio/reader.go +++ b/paddle/go/recordio/range_scanner.go @@ -74,8 +74,8 @@ func (r *Index) Locate(recordIndex int) (int, int) { return -1, -1 } -// Scanner scans records in a specified range within [0, numRecords). -type Scanner struct { +// RangeScanner scans records in a specified range within [0, numRecords). +type RangeScanner struct { reader io.ReadSeeker index *Index start, end, cur int @@ -84,10 +84,10 @@ type Scanner struct { err error } -// NewScanner creates a scanner that sequencially reads records in the +// NewRangeScanner creates a scanner that sequencially reads records in the // range [start, start+len). If start < 0, it scans from the // beginning. If len < 0, it scans till the end of file. -func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { +func NewRangeScanner(r io.ReadSeeker, index *Index, start, len int) *RangeScanner { if start < 0 { start = 0 } @@ -95,7 +95,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { len = index.NumRecords() - start } - return &Scanner{ + return &RangeScanner{ reader: r, index: index, start: start, @@ -108,7 +108,7 @@ func NewScanner(r io.ReadSeeker, index *Index, start, len int) *Scanner { // Scan moves the cursor forward for one record and loads the chunk // containing the record if not yet. -func (s *Scanner) Scan() bool { +func (s *RangeScanner) Scan() bool { s.cur++ if s.cur >= s.end { @@ -124,12 +124,17 @@ func (s *Scanner) Scan() bool { } // Record returns the record under the current cursor. -func (s *Scanner) Record() []byte { +func (s *RangeScanner) Record() []byte { _, ri := s.index.Locate(s.cur) return s.chunk.records[ri] } -// Error returns the error that stopped Scan. -func (s *Scanner) Error() error { +// Err returns the first non-EOF error that was encountered by the +// Scanner. +func (s *RangeScanner) Err() error { + if s.err == io.EOF { + return nil + } + return s.err } diff --git a/paddle/go/recordio/recordio_internal_test.go b/paddle/go/recordio/recordio_internal_test.go index e0f7dd0407caaf38e8113660239d1a0c6eb8afa1..30e317925d8c95e64a42bd8ac5a1dd43b95ee81d 100644 --- a/paddle/go/recordio/recordio_internal_test.go +++ b/paddle/go/recordio/recordio_internal_test.go @@ -68,7 +68,7 @@ func TestWriteAndRead(t *testing.T) { 2*4)}, // two record legnths idx.chunkOffsets) - s := NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) + s := NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) i := 0 for s.Scan() { assert.Equal(data[i], string(s.Record())) diff --git a/paddle/go/recordio/recordio_test.go b/paddle/go/recordio/recordio_test.go index 8bf1b020ab75ca66c12b713526e010756c364217..ab117d2050e6ac18ef63021081a684f259299803 100644 --- a/paddle/go/recordio/recordio_test.go +++ b/paddle/go/recordio/recordio_test.go @@ -29,7 +29,7 @@ func TestWriteRead(t *testing.T) { t.Fatal("num record does not match:", idx.NumRecords(), total) } - s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) + s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), idx, -1, -1) i := 0 for s.Scan() { if !reflect.DeepEqual(s.Record(), make([]byte, i)) { @@ -66,7 +66,7 @@ func TestChunkIndex(t *testing.T) { for i := 0; i < total; i++ { newIdx := idx.ChunkIndex(i) - s := recordio.NewScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1) + s := recordio.NewRangeScanner(bytes.NewReader(buf.Bytes()), newIdx, -1, -1) j := 0 for s.Scan() { if !reflect.DeepEqual(s.Record(), make([]byte, i)) { diff --git a/paddle/go/recordio/scanner.go b/paddle/go/recordio/scanner.go new file mode 100644 index 0000000000000000000000000000000000000000..865228ff651c6eee2cf1fa05ec38a4964394b6dc --- /dev/null +++ b/paddle/go/recordio/scanner.go @@ -0,0 +1,140 @@ +package recordio + +import ( + "fmt" + "os" + "path/filepath" +) + +// Scanner is a scanner for multiple recordio files. +type Scanner struct { + paths []string + curFile *os.File + curScanner *RangeScanner + pathIdx int + end bool + err error +} + +// NewScanner creates a new Scanner. +func NewScanner(paths ...string) (*Scanner, error) { + var ps []string + for _, s := range paths { + match, err := filepath.Glob(s) + if err != nil { + return nil, err + } + + ps = append(ps, match...) + } + + if len(ps) == 0 { + return nil, fmt.Errorf("no valid path provided: %v", paths) + } + + return &Scanner{paths: ps}, nil +} + +// Scan moves the cursor forward for one record and loads the chunk +// containing the record if not yet. +func (s *Scanner) Scan() bool { + if s.err != nil { + return false + } + + if s.end { + return false + } + + if s.curScanner == nil { + more, err := s.nextFile() + if err != nil { + s.err = err + return false + } + + if !more { + s.end = true + return false + } + } + + curMore := s.curScanner.Scan() + s.err = s.curScanner.Err() + + if s.err != nil { + return curMore + } + + if !curMore { + err := s.curFile.Close() + if err != nil { + s.err = err + return false + } + s.curFile = nil + + more, err := s.nextFile() + if err != nil { + s.err = err + return false + } + + if !more { + s.end = true + return false + } + + return s.Scan() + } + return true +} + +// Err returns the first non-EOF error that was encountered by the +// Scanner. +func (s *Scanner) Err() error { + return s.err +} + +// Record returns the record under the current cursor. +func (s *Scanner) Record() []byte { + if s.curScanner == nil { + return nil + } + + return s.curScanner.Record() +} + +// Close release the resources. +func (s *Scanner) Close() error { + s.curScanner = nil + if s.curFile != nil { + err := s.curFile.Close() + s.curFile = nil + return err + } + return nil +} + +func (s *Scanner) nextFile() (bool, error) { + if s.pathIdx >= len(s.paths) { + return false, nil + } + + path := s.paths[s.pathIdx] + s.pathIdx++ + f, err := os.Open(path) + if err != nil { + return false, err + } + + idx, err := LoadIndex(f) + if err != nil { + f.Close() + return false, err + } + + s.curFile = f + s.curScanner = NewRangeScanner(f, idx, 0, -1) + return true, nil +}