diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py index f3e3fbf48362167e4f92f978885f74be38fe2a96..dfa7496cf5aa5410739c525ecb9258b2eb85cbab 100644 --- a/python/paddle/trainer/PyDataProvider2.py +++ b/python/paddle/trainer/PyDataProvider2.py @@ -106,9 +106,7 @@ def integer_value_sequence(dim): def integer_value_sub_sequence(dim): return integer_value(dim, seq_type=SequenceType.SUB_SEQUENCE) - -def integer_sequence(dim): - return index_slot(dim, seq_type=SequenceType.SEQUENCE) +integer_sequence = integer_value_sequence class SingleSlotWrapper(object): diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py index c62553f54cc304259015f1e8e6733eff99f476f2..fc72014c91eda6d399f717a7feda6334f83741e5 100644 --- a/python/paddle/trainer_config_helpers/data_sources.py +++ b/python/paddle/trainer_config_helpers/data_sources.py @@ -78,21 +78,6 @@ def define_py_data_source(file_list, if not isinstance(args, basestring) and args is not None: args = pickle.dumps(args, 0) - if data_cls is None: - - def py_data2(files, load_data_module, load_data_object, load_data_args, - **kwargs): - data = DataBase() - data.type = 'py2' - data.files = files - data.load_data_module = load_data_module - data.load_data_object = load_data_object - data.load_data_args = load_data_args - data.async_load_data = True - return data - - data_cls = py_data2 - cls( data_cls( files=file_list, @@ -207,10 +192,21 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None): :return: None :rtype: None """ + def py_data2(files, load_data_module, load_data_object, load_data_args, + **kwargs): + data = DataBase() + data.type = 'py2' + data.files = files + data.load_data_module = load_data_module + data.load_data_object = load_data_object + data.load_data_args = load_data_args + data.async_load_data = True + return data + define_py_data_sources( train_list=train_list, test_list=test_list, module=module, obj=obj, args=args, - data_cls=None) + data_cls=py_data2)