提交 7fca796e 编写于 作者: Y yaoxuefeng6

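Move data_generator out of fleet.dataset into its own fleet.data_generator package, and export MultiSlotDataGenerator / MultiSlotStringDataGenerator directly from paddle.distributed.fleet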

上级 6ef1fbb6
...@@ -31,8 +31,13 @@ __all__ = ["spawn"] ...@@ -31,8 +31,13 @@ __all__ = ["spawn"]
# dygraph parallel apis # dygraph parallel apis
__all__ += [ __all__ += [
"init_parallel_env", "get_rank", "get_world_size", "prepare_context", "init_parallel_env",
"ParallelEnv", "InMemoryDataset", "QueueDataset" "get_rank",
"get_world_size",
"prepare_context",
"ParallelEnv",
"InMemoryDataset",
"QueueDataset",
] ]
# collective apis # collective apis
......
...@@ -17,7 +17,8 @@ from .base.role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker ...@@ -17,7 +17,8 @@ from .base.role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker
from .base.distributed_strategy import DistributedStrategy from .base.distributed_strategy import DistributedStrategy
from .base.fleet_base import Fleet from .base.fleet_base import Fleet
from .base.util_factory import UtilBase from .base.util_factory import UtilBase
from .dataset import * #from .dataset import *
from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
#from . import metrics #from . import metrics
__all__ = [ __all__ = [
...@@ -26,6 +27,8 @@ __all__ = [ ...@@ -26,6 +27,8 @@ __all__ = [
"UserDefinedRoleMaker", "UserDefinedRoleMaker",
"PaddleCloudRoleMaker", "PaddleCloudRoleMaker",
"Fleet", "Fleet",
"MultiSlotDataGenerator",
"MultiSlotStringDataGenerator",
] ]
fleet = Fleet() fleet = Fleet()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_generator import *
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
import os import os
import sys import sys
__all__ = ['MultiSlotDataGenerator', 'MultiSlotStringDataGenerator']
class DataGenerator(object): class DataGenerator(object):
""" """
......
...@@ -10,10 +10,11 @@ ...@@ -10,10 +10,11 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
from paddle.distributed.fleet.dataset import data_generator as dg import paddle
import paddle.distributed.fleet as fleet
class SyntheticData(dg.MultiSlotDataGenerator): class SyntheticData(fleet.MultiSlotDataGenerator):
def generate_sample(self, line): def generate_sample(self, line):
def data_iter(): def data_iter():
for i in range(10000): for i in range(10000):
...@@ -22,15 +23,17 @@ class SyntheticData(dg.MultiSlotDataGenerator): ...@@ -22,15 +23,17 @@ class SyntheticData(dg.MultiSlotDataGenerator):
return data_iter return data_iter
class SyntheticStringData(dg.MultiSlotStringDataGenerator): class SyntheticStringData(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line): def generate_sample(self, line):
def data_iter(): def data_iter():
for i in range(10000): for i in range(10000):
yield ("words", ["1", "2", "3", "4"], ("label", ["0"])) yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])]
return data_iter
sd = SyntheticData() sd = SyntheticData()
sd.run_from_memory() sd.run_from_memory()
sd2 = SyntheticStringData() sd2 = SyntheticStringData()
sd.run_from_memory() sd2.run_from_memory()
...@@ -12,4 +12,3 @@ ...@@ -12,4 +12,3 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
from .dataset import * from .dataset import *
from .data_generator import *
...@@ -19,7 +19,7 @@ import tarfile ...@@ -19,7 +19,7 @@ import tarfile
import os import os
import paddle import paddle
from paddle.distributed.fleet.dataset import data_generator as data_generator import paddle.distributed.fleet as fleet
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
logger = get_logger( logger = get_logger(
...@@ -59,7 +59,7 @@ def load_lr_input_record(sent): ...@@ -59,7 +59,7 @@ def load_lr_input_record(sent):
return res return res
class DatasetCtrReader(data_generator.MultiSlotDataGenerator): class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line): def generate_sample(self, line):
def iter(): def iter():
fs = line.strip().split('\t') fs = line.strip().split('\t')
......
...@@ -22,7 +22,8 @@ import random ...@@ -22,7 +22,8 @@ import random
import warnings import warnings
import paddle import paddle
from paddle.distributed.fleet.dataset import data_generator as data_generator import paddle.distributed.fleet as fleet
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger("paddle") logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
...@@ -83,7 +84,7 @@ class CtrReader(object): ...@@ -83,7 +84,7 @@ class CtrReader(object):
return reader return reader
class DatasetCtrReader(data_generator.MultiSlotDataGenerator): class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line): def generate_sample(self, line):
def get_rand(low=0.0, high=1.0): def get_rand(low=0.0, high=1.0):
return random.random() return random.random()
......
...@@ -21,13 +21,13 @@ import tarfile ...@@ -21,13 +21,13 @@ import tarfile
import random import random
import paddle import paddle
from paddle.distributed.fleet.dataset import data_generator as data_generator import paddle.distributed.fleet as fleet
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger("paddle") logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class DatasetSimnetReader(data_generator.MultiSlotDataGenerator): class DatasetSimnetReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line): def generate_sample(self, line):
pass pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册