diff --git a/python/paddle/reader/__init__.py b/python/paddle/reader/__init__.py index 493b410e8299ebe167be43ead1401a6ab245a631..7373dc461b1d3115c03b37c5102a469a52aa7441 100644 --- a/python/paddle/reader/__init__.py +++ b/python/paddle/reader/__init__.py @@ -21,3 +21,5 @@ # # r = paddle.reader.buffered(paddle.reader.creator.text("hello.txt")) from decorator import * + +import creator diff --git a/python/paddle/reader/creator.py b/python/paddle/reader/creator.py new file mode 100644 index 0000000000000000000000000000000000000000..5a91bb0b8ef6d1874737386897f6c555eaec18d4 --- /dev/null +++ b/python/paddle/reader/creator.py @@ -0,0 +1,53 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['np_array', 'text_file'] + + +def np_array(x): + """ + Creates a reader that yields elements of x, if it is a + numpy vector. Or rows of x, if it is a numpy matrix. + Or any sub-hyperplane indexed by the highest dimension. + + :param x: the numpy array to create reader from. + :returns: data reader created from x. + """ + + def reader(): + if x.ndim < 1: + yield x + + for e in x: + yield e + + return reader + + +def text_file(path): + """ + Creates a data reader that outputs text line by line from given text file. + Trailing new line ('\n') of each line will be removed. + + :path: path of the text file. + :returns: data reader of text file + """ + + def reader(): + f = open(path, "r") + for l in f: + yield l.rstrip('\n') + f.close() + + return reader diff --git a/python/paddle/reader/tests/CMakeLists.txt b/python/paddle/reader/tests/CMakeLists.txt index 502c897d8946a838847c1c23b1236358c58c088e..da072fb3dbeed5bee15fa1f64372ad8dec497070 100644 --- a/python/paddle/reader/tests/CMakeLists.txt +++ b/python/paddle/reader/tests/CMakeLists.txt @@ -2,3 +2,8 @@ add_test(NAME reader_decorator_test COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/reader/tests/decorator_test.py WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) + +add_test(NAME reader_creator_test + COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ + ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/reader/tests/creator_test.py + WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) diff --git a/python/paddle/reader/tests/creator_test.py b/python/paddle/reader/tests/creator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eda8ab6715b2be0c9cb6163adf60d8fbdf2d7e8c --- /dev/null +++ b/python/paddle/reader/tests/creator_test.py @@ -0,0 +1,38 @@ +# Copyright PaddlePaddle contributors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import paddle.reader.creator +import numpy as np +import os + + +class TestNumpyArray(unittest.TestCase): + def test_numpy_array(self): + l = [[1, 2, 3], [4, 5, 6]] + x = np.array(l, np.int32) + reader = paddle.reader.creator.np_array(x) + for idx, e in enumerate(reader()): + self.assertItemsEqual(e, l[idx]) + + +class TestTextFile(unittest.TestCase): + def test_text_file(self): + path = os.path.join(os.path.dirname(__file__), "test_data_creator.txt") + reader = paddle.reader.creator.text_file(path) + for idx, e in enumerate(reader()): + self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/reader/tests/test_data_creator.txt b/python/paddle/reader/tests/test_data_creator.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2a8d47d43868d369083808497697da79e620e31 --- /dev/null +++ b/python/paddle/reader/tests/test_data_creator.txt @@ -0,0 +1,3 @@ +0 1 +2 3 +4 5