提交 357f0da7 编写于 作者: T tangwei

add reader implement

上级 cf82361a
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from abc import ABC
from fleetrec.reader.reader import Reader
from fleetrec.utils import envs
class TrainReader(Reader):
def init(self):
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
self.cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
self.hash_dim_ = envs.get_global_env("hyper_parameters.sparse_feature_number", None, "train.model")
self.continuous_range_ = range(1, 14)
self.categorical_range_ = range(14, 40)
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.rstrip('\n').split('\t')
dense_feature = []
sparse_feature = []
for idx in self.continuous_range_:
if features[idx] == "":
dense_feature.append(0.0)
else:
dense_feature.append(
(float(features[idx]) - self.cont_min_[idx - 1]) /
self.cont_diff_[idx - 1])
for idx in self.categorical_range_:
sparse_feature.append(
[hash(str(idx) + features[idx]) % self.hash_dim_])
label = [int(features[0])]
feature_name = ["dense_input"]
for idx in self.categorical_range_:
feature_name.append("C" + str(idx - 13))
feature_name.append("label")
yield zip(feature_name, [dense_feature] + sparse_feature + [label])
return reader
class EvaluateReader(Reader, ABC):
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import abc
import paddle.fluid.incubate.data_generator as dg
class Reader(dg.MultiSlotDataGenerator):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def init(self):
pass
@abc.abstractmethod
def generate_sample(self, line):
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
from fleetrec.utils.envs import lazy_instance
if len(sys.argv) != 3:
raise ValueError("reader only accept two argument: initialized reader class name and TRAIN/EVALUATE")
reader_package = sys.argv[1]
if sys.argv[2] == "TRAIN":
reader_name = "TrainReader"
else:
reader_name = "EvaluateReader"
reader_class = lazy_instance(reader_package, reader_name)
reader = reader_class()
reader.run_from_stdin()
......@@ -21,8 +21,8 @@ import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from .trainer import Trainer
from ..utils import envs
from fleetrec.trainer import Trainer
from fleetrec.utils import envs
class TranspileTrainer(Trainer):
......@@ -113,9 +113,8 @@ class TranspileTrainer(Trainer):
def instance(self, context):
models = envs.get_global_env("train.model.models")
model_package = __import__(models, globals(), locals(), models.split("."))
train_model = getattr(model_package, 'Train')
self.model = train_model(None)
model_class = envs.lazy_instance(models, "TrainNet")
self.model = model_class(None)
context['status'] = 'init_pass'
def init(self, context):
......
......@@ -81,3 +81,11 @@ def pretty_print_envs(envs, header=None):
_str = "\n{}\n".format(draws)
return _str
def lazy_instance(package, class_name):
models = get_global_env("train.model.models")
model_package = __import__(package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册