From 7fca796eb2a7d9b57b64262c0351c51618efd710 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Fri, 18 Sep 2020 15:57:20 +0800 Subject: [PATCH] modify data_generator to fleet --- python/paddle/distributed/__init__.py | 9 +++++++-- python/paddle/distributed/fleet/__init__.py | 5 ++++- .../distributed/fleet/data_generator/__init__.py | 14 ++++++++++++++ .../{dataset => data_generator}/data_generator.py | 2 -- .../test_data_generator.py | 13 ++++++++----- .../paddle/distributed/fleet/dataset/__init__.py | 1 - .../incubate/fleet/tests/ctr_dataset_reader.py | 4 ++-- .../fluid/tests/unittests/ctr_dataset_reader.py | 5 +++-- .../fluid/tests/unittests/simnet_dataset_reader.py | 4 ++-- 9 files changed, 40 insertions(+), 17 deletions(-) create mode 100644 python/paddle/distributed/fleet/data_generator/__init__.py rename python/paddle/distributed/fleet/{dataset => data_generator}/data_generator.py (99%) rename python/paddle/distributed/fleet/{dataset => data_generator}/test_data_generator.py (76%) diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 27c82227316..9730e9f95b6 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -31,8 +31,13 @@ __all__ = ["spawn"] # dygraph parallel apis __all__ += [ - "init_parallel_env", "get_rank", "get_world_size", "prepare_context", - "ParallelEnv", "InMemoryDataset", "QueueDataset" + "init_parallel_env", + "get_rank", + "get_world_size", + "prepare_context", + "ParallelEnv", + "InMemoryDataset", + "QueueDataset", ] # collective apis diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 2539fa57a34..627a155d0be 100644 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -17,7 +17,8 @@ from .base.role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker from .base.distributed_strategy import DistributedStrategy from .base.fleet_base import Fleet from .base.util_factory import UtilBase -from .dataset import * +#from .dataset import * +from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator #from . import metrics __all__ = [ @@ -26,6 +27,8 @@ __all__ = [ "UserDefinedRoleMaker", "PaddleCloudRoleMaker", "Fleet", + "MultiSlotDataGenerator", + "MultiSlotStringDataGenerator", ] fleet = Fleet() diff --git a/python/paddle/distributed/fleet/data_generator/__init__.py b/python/paddle/distributed/fleet/data_generator/__init__.py new file mode 100644 index 00000000000..481df4064a4 --- /dev/null +++ b/python/paddle/distributed/fleet/data_generator/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +from .data_generator import * diff --git a/python/paddle/distributed/fleet/dataset/data_generator.py b/python/paddle/distributed/fleet/data_generator/data_generator.py similarity index 99% rename from python/paddle/distributed/fleet/dataset/data_generator.py rename to python/paddle/distributed/fleet/data_generator/data_generator.py index 77c60c8dd30..a10c4b8fe5c 100644 --- a/python/paddle/distributed/fleet/dataset/data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/data_generator.py @@ -15,8 +15,6 @@ import os import sys -__all__ = ['MultiSlotDataGenerator', 'MultiSlotStringDataGenerator'] - class DataGenerator(object): """ diff --git a/python/paddle/distributed/fleet/dataset/test_data_generator.py b/python/paddle/distributed/fleet/data_generator/test_data_generator.py similarity index 76% rename from python/paddle/distributed/fleet/dataset/test_data_generator.py rename to python/paddle/distributed/fleet/data_generator/test_data_generator.py index 8d2e2237d6e..60cbaf0bd36 100644 --- a/python/paddle/distributed/fleet/dataset/test_data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/test_data_generator.py @@ -10,10 +10,11 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from paddle.distributed.fleet.dataset import data_generator as dg +import paddle +import paddle.distributed.fleet as fleet -class SyntheticData(dg.MultiSlotDataGenerator): +class SyntheticData(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def data_iter(): for i in range(10000): @@ -22,15 +23,17 @@ class SyntheticData(dg.MultiSlotDataGenerator): return data_iter -class SyntheticStringData(dg.MultiSlotStringDataGenerator): +class SyntheticStringData(fleet.MultiSlotStringDataGenerator): def generate_sample(self, line): def data_iter(): for i in range(10000): - yield ("words", ["1", "2", "3", "4"], ("label", ["0"])) + yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])] + + return data_iter sd = SyntheticData() sd.run_from_memory() sd2 = SyntheticStringData() -sd.run_from_memory() +sd2.run_from_memory() diff --git a/python/paddle/distributed/fleet/dataset/__init__.py b/python/paddle/distributed/fleet/dataset/__init__.py index 055bd984c67..af33c4eafb3 100644 --- a/python/paddle/distributed/fleet/dataset/__init__.py +++ b/python/paddle/distributed/fleet/dataset/__init__.py @@ -12,4 +12,3 @@ # See the License for the specific language governing permissions and from .dataset import * -from .data_generator import * diff --git a/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py b/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py index 1407e92fd38..83343933074 100644 --- a/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py +++ b/python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py @@ -19,7 +19,7 @@ import tarfile import os import paddle -from paddle.distributed.fleet.dataset import data_generator as data_generator +import paddle.distributed.fleet as fleet from paddle.fluid.log_helper import get_logger logger = get_logger( @@ -59,7 +59,7 @@ def load_lr_input_record(sent): return res -class DatasetCtrReader(data_generator.MultiSlotDataGenerator): +class DatasetCtrReader(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def iter(): fs = line.strip().split('\t') diff --git a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py index 6799d943df1..f447be23932 100644 --- a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py +++ b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py @@ -22,7 +22,8 @@ import random import warnings import paddle -from paddle.distributed.fleet.dataset import data_generator as data_generator +import paddle.distributed.fleet as fleet + logging.basicConfig() logger = logging.getLogger("paddle") logger.setLevel(logging.INFO) @@ -83,7 +84,7 @@ class CtrReader(object): return reader -class DatasetCtrReader(data_generator.MultiSlotDataGenerator): +class DatasetCtrReader(fleet.MultiSlotDataGenerator): def generate_sample(self, line): def get_rand(low=0.0, high=1.0): return random.random() diff --git a/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py b/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py index dea23ae97f3..737677ccf90 100644 --- a/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py +++ b/python/paddle/fluid/tests/unittests/simnet_dataset_reader.py @@ -21,13 +21,13 @@ import tarfile import random import paddle -from paddle.distributed.fleet.dataset import data_generator as data_generator +import paddle.distributed.fleet as fleet logging.basicConfig() logger = logging.getLogger("paddle") logger.setLevel(logging.INFO) -class DatasetSimnetReader(data_generator.MultiSlotDataGenerator): +class DatasetSimnetReader(fleet.MultiSlotDataGenerator): def generate_sample(self, line): pass -- GitLab