提交 e6ae1e4f 编写于 作者: M minqiyang

Replace the dependency of paddle.v2 dataset

上级 6abe819f
...@@ -53,7 +53,7 @@ def reader_creator(filename, sub_name, cycle=False): ...@@ -53,7 +53,7 @@ def reader_creator(filename, sub_name, cycle=False):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader(): def reader():
with tarfile.open(filename, mode='rb') as f: with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f names = (each_item.name for each_item in f
if sub_name in each_item.name) if sub_name in each_item.name)
......
...@@ -19,10 +19,12 @@ http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and ...@@ -19,10 +19,12 @@ http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and
parse training set and test set into paddle reader creators. parse training set and test set into paddle reader creators.
""" """
import six
import tarfile import tarfile
import gzip import gzip
import paddle.dataset.common import paddle.dataset.common
import paddle.fluid.compat as cpt
__all__ = [ __all__ = [
'train', 'train',
...@@ -40,8 +42,8 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/' ...@@ -40,8 +42,8 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
'wmt_shrinked_data/wmt14.tgz') 'wmt_shrinked_data/wmt14.tgz')
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c' MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92 # BLEU of this trained model is 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' URL_MODEL = 'http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz'
MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3' MD5_MODEL = '0791583d57d5beb693b9414c5b36798c'
START = "<s>" START = "<s>"
END = "<e>" END = "<e>"
...@@ -54,7 +56,7 @@ def __read_to_dict(tar_file, dict_size): ...@@ -54,7 +56,7 @@ def __read_to_dict(tar_file, dict_size):
out_dict = dict() out_dict = dict()
for line_count, line in enumerate(fd): for line_count, line in enumerate(fd):
if line_count < size: if line_count < size:
out_dict[line.strip()] = line_count out_dict[cpt.to_literal_str(line.strip())] = line_count
else: else:
break break
return out_dict return out_dict
...@@ -85,7 +87,7 @@ def reader_creator(tar_file, file_name, dict_size): ...@@ -85,7 +87,7 @@ def reader_creator(tar_file, file_name, dict_size):
] ]
for name in names: for name in names:
for line in f.extractfile(name): for line in f.extractfile(name):
line_split = line.strip().split('\t') line_split = line.strip().split(six.b('\t'))
if len(line_split) != 2: if len(line_split) != 2:
continue continue
src_seq = line_split[0] # one source sequence src_seq = line_split[0] # one source sequence
......
...@@ -35,6 +35,7 @@ import gzip ...@@ -35,6 +35,7 @@ import gzip
from collections import defaultdict from collections import defaultdict
import paddle.dataset.common import paddle.dataset.common
import paddle.fluid.compat as cpt
__all__ = [ __all__ = [
"train", "train",
...@@ -82,16 +83,16 @@ def __load_dict(tar_file, dict_size, lang, reverse=False): ...@@ -82,16 +83,16 @@ def __load_dict(tar_file, dict_size, lang, reverse=False):
dict_path = os.path.join(paddle.dataset.common.DATA_HOME, dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size)) "wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or ( if not os.path.exists(dict_path) or (
len(open(dict_path, "r").readlines()) != dict_size): len(open(dict_path, "rb").readlines()) != dict_size):
__build_dict(tar_file, dict_size, dict_path, lang) __build_dict(tar_file, dict_size, dict_path, lang)
word_dict = {} word_dict = {}
with open(dict_path, "r") as fdict: with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if reverse: if reverse:
word_dict[idx] = line.strip() word_dict[idx] = cpt.to_literal_str(line.strip())
else: else:
word_dict[line.strip()] = idx word_dict[cpt.to_literal_str(line.strip())] = idx
return word_dict return word_dict
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle as paddle import paddle
import numpy as np import numpy as np
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.v2 as paddle import paddle.dataset.mnist as mnist
import paddle.v2.dataset.mnist as mnist
class TestPreprocessor(unittest.TestCase): class TestPreprocessor(unittest.TestCase):
......
...@@ -93,7 +93,7 @@ class TestProfiler(unittest.TestCase): ...@@ -93,7 +93,7 @@ class TestProfiler(unittest.TestCase):
"profiler is enabled only with GPU") "profiler is enabled only with GPU")
def test_all_profiler(self): def test_all_profiler(self):
self.net_profiler('All', '/tmp/profile_out') self.net_profiler('All', '/tmp/profile_out')
with open('/tmp/profile_out', 'r') as f: with open('/tmp/profile_out', 'rb') as f:
self.assertGreater(len(f.read()), 0) self.assertGreater(len(f.read()), 0)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import unittest import unittest
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.compat as cpt
from paddle.fluid.framework import Program from paddle.fluid.framework import Program
...@@ -108,7 +109,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -108,7 +109,7 @@ class TestVarDesc(unittest.TestCase):
def test_shape(self): def test_shape(self):
program_desc = core.ProgramDesc() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_var') var = block.var(cpt.to_bytes('my_var'))
var.set_type(core.VarDesc.VarType.SELECTED_ROWS) var.set_type(core.VarDesc.VarType.SELECTED_ROWS)
src_shape = [3, 2, 10, 8] src_shape = [3, 2, 10, 8]
var.set_shape(src_shape) var.set_shape(src_shape)
...@@ -119,7 +120,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -119,7 +120,7 @@ class TestVarDesc(unittest.TestCase):
def test_multiple_shape(self): def test_multiple_shape(self):
program_desc = core.ProgramDesc() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_reader') var = block.var(cpt.to_bytes('my_reader'))
var.set_type(core.VarDesc.VarType.READER) var.set_type(core.VarDesc.VarType.READER)
src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]] src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]]
var.set_shapes(src_shapes) var.set_shapes(src_shapes)
...@@ -130,7 +131,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -130,7 +131,7 @@ class TestVarDesc(unittest.TestCase):
def test_dtype(self): def test_dtype(self):
program_desc = core.ProgramDesc() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_var') var = block.var(cpt.to_bytes('my_var'))
var.set_type(core.VarDesc.VarType.LOD_TENSOR) var.set_type(core.VarDesc.VarType.LOD_TENSOR)
var.set_dtype(core.VarDesc.VarType.INT32) var.set_dtype(core.VarDesc.VarType.INT32)
self.assertEqual(core.VarDesc.VarType.INT32, var.dtype()) self.assertEqual(core.VarDesc.VarType.INT32, var.dtype())
...@@ -139,7 +140,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -139,7 +140,7 @@ class TestVarDesc(unittest.TestCase):
def test_multiple_dtype(self): def test_multiple_dtype(self):
program_desc = core.ProgramDesc() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_reader') var = block.var(cpt.to_bytes('my_reader'))
var.set_type(core.VarDesc.VarType.READER) var.set_type(core.VarDesc.VarType.READER)
src_types = [ src_types = [
core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64, core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64,
...@@ -152,7 +153,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -152,7 +153,7 @@ class TestVarDesc(unittest.TestCase):
def test_multiple_lod_level(self): def test_multiple_lod_level(self):
program_desc = core.ProgramDesc() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_reader') var = block.var(cpt.to_bytes('my_reader'))
var.set_type(core.VarDesc.VarType.READER) var.set_type(core.VarDesc.VarType.READER)
src_types = [3, 1, 2] src_types = [3, 1, 2]
var.set_lod_levels(src_types) var.set_lod_levels(src_types)
...@@ -166,12 +167,12 @@ class TestBlockDesc(unittest.TestCase): ...@@ -166,12 +167,12 @@ class TestBlockDesc(unittest.TestCase):
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
block = program_desc.block(0) block = program_desc.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
var1 = block.var("var1") var1 = block.var(cpt.to_bytes("var1"))
var2 = block.var("var2") var2 = block.var(cpt.to_bytes("var2"))
var3 = block.var("var3") var3 = block.var(cpt.to_bytes("var3"))
all_vars = block.all_vars() all_vars = block.all_vars()
self.assertEqual(set(all_vars), {var1, var2, var3}) self.assertEqual(set(all_vars), {var1, var2, var3})
var2_re = block.find_var("var2") var2_re = block.find_var(cpt.to_bytes("var2"))
self.assertEqual(var2_re, var2) self.assertEqual(var2_re, var2)
def test_add_op(self): def test_add_op(self):
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle as paddle import paddle
import numpy as np import numpy as np
import unittest import unittest
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.v2 as paddle import paddle
import paddle.v2.dataset.mnist as mnist import paddle.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase): class TestRecordIO(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册