diff --git a/paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto_data b/paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto_data deleted file mode 100644 index f189b21e86a50d70d317b5e43aa2d6e05af5e774..0000000000000000000000000000000000000000 Binary files a/paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto_data and /dev/null differ diff --git a/paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.protolist b/paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.protolist deleted file mode 100644 index 6b406dff0ba91b5f310d7eafa111c0d21d6542c3..0000000000000000000000000000000000000000 --- a/paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.protolist +++ /dev/null @@ -1 +0,0 @@ -./trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto_data diff --git a/paddle/trainer/tests/testPyDataWrapper.py b/paddle/trainer/tests/testPyDataWrapper.py index 2c29a274339747b78fbd6c27ae4070f0abbd4028..a76eeeacb91cdba305d2f71c6292f79e4b98dd73 100644 --- a/paddle/trainer/tests/testPyDataWrapper.py +++ b/paddle/trainer/tests/testPyDataWrapper.py @@ -20,28 +20,6 @@ import random import json import string - -@provider(slots=[ - SparseNonValueSlot(10), DenseSlot(2), SparseValueSlot(10), StringSlot(1), - IndexSlot(3) -]) -def processNonSequenceData(obj, filename): - with open(filename, "rb") as f: - for line in f: - slots_str = line.split(';') - index = int(slots_str[0]) - non_values = map(int, slots_str[1].split()[1:]) - dense = map(float, slots_str[2].split()[1:]) - strs = slots_str[4].strip().split(' ', 1)[1] - - def __values_mapper__(s): - s = s.split(":") - return int(s[0]), float(s[1]) - - values = map(__values_mapper__, slots_str[3].split()[1:]) - yield [non_values, dense, values, strs, index] - - SPARSE_ID_LIMIT = 1000 SPARSE_ID_COUNT = 100 SEQUENCE_LIMIT = 50 @@ -146,8 +124,6 @@ def processSubSeqAndGenerateData(obj, name): if __name__ == "__main__": - pvd = processNonSequenceData("test.txt") - print pvd.getNextBatch(100) pvd = processSeqAndGenerateData("_") print pvd.getNextBatch(100) pvd = processSubSeqAndGenerateData("_") diff --git a/paddle/trainer/tests/test_PyDataProviderWrapper.cpp b/paddle/trainer/tests/test_PyDataProviderWrapper.cpp index 66ec65e340a435a7260028611828fb28845e0728..92dc8aa9ec5ce281d1950d84260c1b9555e686a7 100644 --- a/paddle/trainer/tests/test_PyDataProviderWrapper.cpp +++ b/paddle/trainer/tests/test_PyDataProviderWrapper.cpp @@ -25,45 +25,9 @@ limitations under the License. */ #include #include "picojson.h" -void checkEqual(const paddle::Argument& expect, const paddle::Argument& actual); void checkValue(std::vector& arguments, picojson::array& arr); const std::string kDir = "./trainer/tests/pydata_provider_wrapper_dir/"; -TEST(PyDataProviderWrapper, NoSequenceData) { - paddle::DataConfig conf; - conf.set_type("py"); - conf.set_load_data_module(std::string("testPyDataWrapper")); - conf.set_load_data_object(std::string("processNonSequenceData")); - conf.set_async_load_data(false); - conf.clear_files(); - conf.set_files(kDir + "test_pydata_provider_wrapper.list"); - paddle::DataProviderPtr provider(paddle::DataProvider::create(conf, false)); - provider->setSkipShuffle(); - provider->reset(); - paddle::DataBatch batchFromPy; - provider->getNextBatch(100, &batchFromPy); - - paddle::DataConfig conf2; - conf2.set_type("proto"); - conf2.set_async_load_data(false); - conf2.clear_files(); - conf2.set_files(kDir + "test_pydata_provider_wrapper.protolist"); - - provider.reset(paddle::DataProvider::create(conf2, false)); - provider->setSkipShuffle(); - provider->reset(); - paddle::DataBatch batchFromProto; - provider->getNextBatch(100, &batchFromProto); - - std::vector& pyArguments = batchFromPy.getStreams(); - std::vector& protoArguments = batchFromProto.getStreams(); - EXPECT_EQ(pyArguments.size(), protoArguments.size()); - - for (size_t i = 0; i < pyArguments.size(); ++i) { - checkEqual(protoArguments[i], pyArguments[i]); - } -} - TEST(PyDataProviderWrapper, SequenceData) { paddle::DataConfig conf; conf.set_type("py"); @@ -148,66 +112,6 @@ int main(int argc, char** argv) { return RUN_ALL_TESTS(); } -void checkEqual(const paddle::Argument& expect, - const paddle::Argument& actual) { - if (expect.value) { - EXPECT_TRUE(actual.value != nullptr); - paddle::Matrix* e = expect.value.get(); - paddle::Matrix* a = actual.value.get(); - EXPECT_EQ(e->getWidth(), a->getWidth()); - EXPECT_EQ(e->getHeight(), a->getHeight()); - if (dynamic_cast(e)) { - paddle::CpuSparseMatrix* se = dynamic_cast(e); - paddle::CpuSparseMatrix* sa = dynamic_cast(a); - EXPECT_EQ(se->getFormat(), sa->getFormat()); - EXPECT_EQ(se->getElementCnt(), sa->getElementCnt()); - size_t rowSize = se->getFormat() == paddle::SPARSE_CSC - ? se->getElementCnt() - : se->getHeight() + 1; - size_t colSize = se->getFormat() == paddle::SPARSE_CSC - ? se->getWidth() + 1 - : se->getElementCnt(); - for (size_t i = 0; i < rowSize; ++i) { - EXPECT_EQ(se->getRows()[i], sa->getRows()[i]); - } - for (size_t i = 0; i < colSize; ++i) { - EXPECT_EQ(se->getCols()[i], sa->getCols()[i]); - } - if (se->getValueType() == paddle::FLOAT_VALUE) { - EXPECT_EQ(paddle::FLOAT_VALUE, sa->getValueType()); - for (size_t i = 0; i < se->getElementCnt(); ++i) { - EXPECT_EQ(se->getValue()[i], sa->getValue()[i]); - } - } - } else if (dynamic_cast(e)) { - EXPECT_EQ(e->getElementCnt(), a->getElementCnt()); - for (size_t i = 0; i < e->getElementCnt(); ++i) { - EXPECT_EQ(e->getData()[i], a->getData()[i]); - } - } - } - - if (expect.ids) { - EXPECT_TRUE(actual.ids != nullptr); - paddle::VectorT* e = expect.ids.get(); - paddle::VectorT* a = actual.ids.get(); - EXPECT_EQ(e->getSize(), a->getSize()); - for (size_t i = 0; i < e->getSize(); ++i) { - EXPECT_EQ(e->getData()[i], a->getData()[i]); - } - } - - if (expect.strs) { - EXPECT_TRUE(actual.strs != nullptr); - std::vector* e = expect.strs.get(); - std::vector* a = actual.strs.get(); - EXPECT_EQ(e->size(), a->size()); - for (size_t i = 0; i < e->size(); ++i) { - EXPECT_EQ((*e)[i], (*a)[i]); - } - } -} - void checkValue(std::vector& arguments, picojson::array& arr) { // CHECK SLOT 0, Sparse Value.