提交 7fca796e 编写于 作者: Y yaoxuefeng6

modify data_generator to fleet

上级 6ef1fbb6
......@@ -31,8 +31,13 @@ __all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env", "get_rank", "get_world_size", "prepare_context",
"ParallelEnv", "InMemoryDataset", "QueueDataset"
"init_parallel_env",
"get_rank",
"get_world_size",
"prepare_context",
"ParallelEnv",
"InMemoryDataset",
"QueueDataset",
]
# collective apis
......
......@@ -17,7 +17,8 @@ from .base.role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker
from .base.distributed_strategy import DistributedStrategy
from .base.fleet_base import Fleet
from .base.util_factory import UtilBase
from .dataset import *
#from .dataset import *
from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
#from . import metrics
__all__ = [
......@@ -26,6 +27,8 @@ __all__ = [
"UserDefinedRoleMaker",
"PaddleCloudRoleMaker",
"Fleet",
"MultiSlotDataGenerator",
"MultiSlotStringDataGenerator",
]
fleet = Fleet()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .data_generator import *
......@@ -15,8 +15,6 @@
import os
import sys
__all__ = ['MultiSlotDataGenerator', 'MultiSlotStringDataGenerator']
class DataGenerator(object):
"""
......
......@@ -10,10 +10,11 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from paddle.distributed.fleet.dataset import data_generator as dg
import paddle
import paddle.distributed.fleet as fleet
class SyntheticData(dg.MultiSlotDataGenerator):
class SyntheticData(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
......@@ -22,15 +23,17 @@ class SyntheticData(dg.MultiSlotDataGenerator):
return data_iter
class SyntheticStringData(dg.MultiSlotStringDataGenerator):
class SyntheticStringData(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield ("words", ["1", "2", "3", "4"], ("label", ["0"]))
yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])]
return data_iter
sd = SyntheticData()
sd.run_from_memory()
sd2 = SyntheticStringData()
sd.run_from_memory()
sd2.run_from_memory()
......@@ -12,4 +12,3 @@
# See the License for the specific language governing permissions and
from .dataset import *
from .data_generator import *
......@@ -19,7 +19,7 @@ import tarfile
import os
import paddle
from paddle.distributed.fleet.dataset import data_generator as data_generator
import paddle.distributed.fleet as fleet
from paddle.fluid.log_helper import get_logger
logger = get_logger(
......@@ -59,7 +59,7 @@ def load_lr_input_record(sent):
return res
class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def iter():
fs = line.strip().split('\t')
......
......@@ -22,7 +22,8 @@ import random
import warnings
import paddle
from paddle.distributed.fleet.dataset import data_generator as data_generator
import paddle.distributed.fleet as fleet
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
......@@ -83,7 +84,7 @@ class CtrReader(object):
return reader
class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def get_rand(low=0.0, high=1.0):
return random.random()
......
......@@ -21,13 +21,13 @@ import tarfile
import random
import paddle
from paddle.distributed.fleet.dataset import data_generator as data_generator
import paddle.distributed.fleet as fleet
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
class DatasetSimnetReader(data_generator.MultiSlotDataGenerator):
class DatasetSimnetReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册