From 62b20ca033568212981c5836aecb3fb1d025c3f2 Mon Sep 17 00:00:00 2001 From: wangyanfei01 Date: Tue, 13 Dec 2016 16:52:21 +0800 Subject: [PATCH] refine data_sources.py and PyDataProvider2.py to make more readable --- python/paddle/trainer/PyDataProvider2.py | 4 +-- .../trainer_config_helpers/data_sources.py | 28 ++++++++----------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py index f3e3fbf4836..dfa7496cf5a 100644 --- a/python/paddle/trainer/PyDataProvider2.py +++ b/python/paddle/trainer/PyDataProvider2.py @@ -106,9 +106,7 @@ def integer_value_sequence(dim): def integer_value_sub_sequence(dim): return integer_value(dim, seq_type=SequenceType.SUB_SEQUENCE) - -def integer_sequence(dim): - return index_slot(dim, seq_type=SequenceType.SEQUENCE) +integer_sequence = integer_value_sequence class SingleSlotWrapper(object): diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py index c62553f54cc..fc72014c91e 100644 --- a/python/paddle/trainer_config_helpers/data_sources.py +++ b/python/paddle/trainer_config_helpers/data_sources.py @@ -78,21 +78,6 @@ def define_py_data_source(file_list, if not isinstance(args, basestring) and args is not None: args = pickle.dumps(args, 0) - if data_cls is None: - - def py_data2(files, load_data_module, load_data_object, load_data_args, - **kwargs): - data = DataBase() - data.type = 'py2' - data.files = files - data.load_data_module = load_data_module - data.load_data_object = load_data_object - data.load_data_args = load_data_args - data.async_load_data = True - return data - - data_cls = py_data2 - cls( data_cls( files=file_list, @@ -207,10 +192,21 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None): :return: None :rtype: None """ + def py_data2(files, load_data_module, load_data_object, load_data_args, + **kwargs): + data = DataBase() + data.type = 'py2' + data.files = files + data.load_data_module = load_data_module + data.load_data_object = load_data_object + data.load_data_args = load_data_args + data.async_load_data = True + return data + define_py_data_sources( train_list=train_list, test_list=test_list, module=module, obj=obj, args=args, - data_cls=None) + data_cls=py_data2) -- GitLab