diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 27c82227316309b370aefe5e0550230c3f703c8c..9730e9f95b6f3a2a321db82a0b7d0109099a0524 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -31,8 +31,13 @@ __all__ = ["spawn"] # dygraph parallel apis __all__ += [ - "init_parallel_env", "get_rank", "get_world_size", "prepare_context", - "ParallelEnv", "InMemoryDataset", "QueueDataset" + "init_parallel_env", + "get_rank", + "get_world_size", + "prepare_context", + "ParallelEnv", + "InMemoryDataset", + "QueueDataset", ] # collective apis diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 2539fa57a34b1fe6fdea6b6b847d52f765df3fa3..627a155d0be19d4488a114ad2c7d1f0110f894bb 100644 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -17,7 +17,8 @@ from .base.role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker from .base.distributed_strategy import DistributedStrategy from .base.fleet_base import Fleet from .base.util_factory import UtilBase -from .dataset import * +#from .dataset import * +from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator #from . import metrics __all__ = [ @@ -26,6 +27,8 @@ __all__ = [ "UserDefinedRoleMaker", "PaddleCloudRoleMaker", "Fleet", + "MultiSlotDataGenerator", + "MultiSlotStringDataGenerator", ] fleet = Fleet() diff --git a/python/paddle/distributed/fleet/data_generator/__init__.py b/python/paddle/distributed/fleet/data_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..481df4064a4ecccfdfe7dc09a707b5297fabf4bc --- /dev/null +++ b/python/paddle/distributed/fleet/data_generator/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +from .data_generator import * diff --git a/python/paddle/distributed/fleet/dataset/data_generator.py b/python/paddle/distributed/fleet/data_generator/data_generator.py similarity index 99% rename from python/paddle/distributed/fleet/dataset/data_generator.py rename to python/paddle/distributed/fleet/data_generator/data_generator.py index 77c60c8dd30f8295c6172cc81aa96000b43b5106..a10c4b8fe5ce08d6dc2d5ccf096547b9a5232da3 100644 --- a/python/paddle/distributed/fleet/dataset/data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/data_generator.py @@ -15,8 +15,6 @@ import os import sys -__all__ = ['MultiSlotDataGenerator', 'MultiSlotStringDataGenerator'] - class DataGenerator(object): """ diff --git a/python/paddle/distributed/fleet/dataset/test_data_generator.py b/python/paddle/distributed/fleet/data_generator/test_data_generator.py similarity index 76% rename from python/paddle/distributed/fleet/dataset/test_data_generator.py rename to python/paddle/distributed/fleet/data_generator/test_data_generator.py index 8d2e2237d6e24effac0867e33a65e9c218dc6e1a..60cbaf0bd364358de18cea5ab4297e8429581b39 100644 --- a/python/paddle/distributed/fleet/dataset/test_data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/test_data_generator.py @@ -10,10 +10,11 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.distributed.fleet.dataset import data_generator as dg +import paddle +import paddle.distributed.fleet as fleet -class SyntheticData(dg.MultiSlotDataGenerator): +class SyntheticData(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def data_iter(): for i in range(10000): @@ -22,15 +23,17 @@ class SyntheticData(dg.MultiSlotDataGenerator): return data_iter -class SyntheticStringData(dg.MultiSlotStringDataGenerator): +class SyntheticStringData(fleet.MultiSlotStringDataGenerator): def generate_sample(self, line): def data_iter(): for i in range(10000): - yield ("words", ["1", "2", "3", "4"], ("label", ["0"])) + yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])] + + return data_iter sd = SyntheticData() sd.run_from_memory() sd2 = SyntheticStringData() -sd.run_from_memory() +sd2.run_from_memory() diff --git a/python/paddle/distributed/fleet/dataset/__init__.py b/python/paddle/distributed/fleet/dataset/__init__.py index 055bd984c671e6796e6c89e9abaf117f82d510bd..af33c4eafb396827335157933d51f37ca8b06011 100644 --- a/python/paddle/distributed/fleet/dataset/__init__.py +++ b/python/paddle/distributed/fleet/dataset/__init__.py @@ -12,4 +12,3 @@ # See the License for the specific language governing permissions and from .dataset import * -from .data_generator import * diff --git a/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py b/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py index 1407e92fd386b3312f4f6f695ecc5a36849ab401..83343933074c06a914d4bd609383e4bf8867437a 100644 --- a/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py +++ b/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py @@ -19,7 +19,7 @@ import tarfile import os import paddle -from paddle.distributed.fleet.dataset import data_generator as data_generator +import paddle.distributed.fleet as fleet from paddle.fluid.log_helper import get_logger logger = get_logger( @@ -59,7 +59,7 @@ def load_lr_input_record(sent): return res -class DatasetCtrReader(data_generator.MultiSlotDataGenerator): +class DatasetCtrReader(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def iter(): fs = line.strip().split('\t') diff --git a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py index 6799d943df1bc6b0544ba724bc234721bf0b5ae6..f447be2393252b4ed758e90a8e3c4e24b03d140d 100644 --- a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py +++ b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py @@ -22,7 +22,8 @@ import random import warnings import paddle -from paddle.distributed.fleet.dataset import data_generator as data_generator +import paddle.distributed.fleet as fleet + logging.basicConfig() logger = logging.getLogger("paddle") logger.setLevel(logging.INFO) @@ -83,7 +84,7 @@ class CtrReader(object): return reader -class DatasetCtrReader(data_generator.MultiSlotDataGenerator): +class DatasetCtrReader(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def get_rand(low=0.0, high=1.0): return random.random() diff --git a/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py b/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py index dea23ae97f3adb7b22601e42f67966d547a4ba20..737677ccf90af18f9c380e6deb0fe2437ceb1221 100644 --- a/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py +++ b/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py @@ -21,13 +21,13 @@ import tarfile import random import paddle -from paddle.distributed.fleet.dataset import data_generator as data_generator +import paddle.distributed.fleet as fleet logging.basicConfig() logger = logging.getLogger("paddle") logger.setLevel(logging.INFO) -class DatasetSimnetReader(data_generator.MultiSlotDataGenerator): +class DatasetSimnetReader(fleet.MultiSlotDataGenerator): def generate_sample(self, line): pass