diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 914dae348bc94d061072543aa14aba2219f4b52d..7d14cc5dc8796c0865ba1b25405a6af6e203c94a 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -116,8 +116,8 @@ def reader_creator(data_file, for file in open(file_list): file = file.strip() batch = None - with open(file, 'r') as f: - batch = pickle.load(f) + with open(file, 'rb') as f: + batch = pickle.loads(f.read()) data = batch['data'] labels = batch['label'] for sample, label in zip(data, batch['label']): diff --git a/python/paddle/reader/tests/creator_test.py b/python/paddle/reader/tests/creator_test.py index c4238c12a74759d52eb09f31ce1126cc93dd3489..567f38c96e73ae9bc333bbda642c423ea0742d34 100644 --- a/python/paddle/reader/tests/creator_test.py +++ b/python/paddle/reader/tests/creator_test.py @@ -29,6 +29,7 @@ import os import unittest import numpy as np import paddle.reader.creator +import six class TestNumpyArray(unittest.TestCase): @@ -37,7 +38,7 @@ class TestNumpyArray(unittest.TestCase): x = np.array(l, np.int32) reader = paddle.reader.creator.np_array(x) for idx, e in enumerate(reader()): - self.assertItemsEqual(e, l[idx]) + six.assertCountEqual(e, l[idx]) class TestTextFile(unittest.TestCase):