diff --git a/fleetrec/models/ctr_dnn/data_generator.py b/fleetrec/models/ctr_dnn/data_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3225cbd1b358b528b551ac21f62e0a56ae4bc3ab --- /dev/null +++ b/fleetrec/models/ctr_dnn/data_generator.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function +from abc import ABC + +from fleetrec.reader.reader import Reader +from fleetrec.utils import envs + + +class TrainReader(Reader): + def init(self): + self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + self.cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50] + self.cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50] + self.hash_dim_ = envs.get_global_env("hyper_parameters.sparse_feature_number", None, "train.model") + self.continuous_range_ = range(1, 14) + self.categorical_range_ = range(14, 40) + + def generate_sample(self, line): + """ + Read the data line by line and process it as a dictionary + """ + + def reader(): + """ + This function needs to be implemented by the user, based on data format + """ + features = line.rstrip('\n').split('\t') + + dense_feature = [] + sparse_feature = [] + for idx in self.continuous_range_: + if features[idx] == "": + dense_feature.append(0.0) + else: + dense_feature.append( + (float(features[idx]) - self.cont_min_[idx - 1]) / + self.cont_diff_[idx - 1]) + + for idx in self.categorical_range_: + sparse_feature.append( + [hash(str(idx) + features[idx]) % self.hash_dim_]) + label = [int(features[0])] + feature_name = ["dense_input"] + for idx in self.categorical_range_: + feature_name.append("C" + str(idx - 13)) + feature_name.append("label") + yield zip(feature_name, [dense_feature] + sparse_feature + [label]) + + return reader + + +class EvaluateReader(Reader, ABC): + pass diff --git a/fleetrec/reader/reader.py b/fleetrec/reader/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..8e696747f99f63ae2f3323ea851323f1c61d50f1 --- /dev/null +++ b/fleetrec/reader/reader.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import abc + +import paddle.fluid.incubate.data_generator as dg + + +class Reader(dg.MultiSlotDataGenerator): + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def init(self): + pass + + @abc.abstractmethod + def generate_sample(self, line): + pass diff --git a/fleetrec/reader/reader_instance.py b/fleetrec/reader/reader_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..2b771794c23d0eaad22f436358ba72027f250e1a --- /dev/null +++ b/fleetrec/reader/reader_instance.py @@ -0,0 +1,31 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function +import sys + +from fleetrec.utils.envs import lazy_instance + +if len(sys.argv) != 3: + raise ValueError("reader only accept two argument: initialized reader class name and TRAIN/EVALUATE") + +reader_package = sys.argv[1] + +if sys.argv[2] == "TRAIN": + reader_name = "TrainReader" +else: + reader_name = "EvaluateReader" + +reader_class = lazy_instance(reader_package, reader_name) +reader = reader_class() +reader.run_from_stdin() diff --git a/fleetrec/trainer/transpiler_trainer.py b/fleetrec/trainer/transpiler_trainer.py index f7fa8e679d27b9480479e5e7f9b4200f76855151..fa568851986e02585f3f556c9f204247116ee1db 100644 --- a/fleetrec/trainer/transpiler_trainer.py +++ b/fleetrec/trainer/transpiler_trainer.py @@ -21,8 +21,8 @@ import paddle.fluid as fluid from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet -from .trainer import Trainer -from ..utils import envs +from fleetrec.trainer import Trainer +from fleetrec.utils import envs class TranspileTrainer(Trainer): @@ -113,9 +113,8 @@ class TranspileTrainer(Trainer): def instance(self, context): models = envs.get_global_env("train.model.models") - model_package = __import__(models, globals(), locals(), models.split(".")) - train_model = getattr(model_package, 'Train') - self.model = train_model(None) + model_class = envs.lazy_instance(models, "TrainNet") + self.model = model_class(None) context['status'] = 'init_pass' def init(self, context): diff --git a/fleetrec/utils/envs.py b/fleetrec/utils/envs.py index 95cf444741c60a88d2d50616b6d9b7097e80d0c9..2cd55dd47d251ee9b9f0b3229595d5284369da6d 100644 --- a/fleetrec/utils/envs.py +++ b/fleetrec/utils/envs.py @@ -81,3 +81,11 @@ def pretty_print_envs(envs, header=None): _str = "\n{}\n".format(draws) return _str + + +def lazy_instance(package, class_name): + models = get_global_env("train.model.models") + model_package = __import__(package, globals(), locals(), package.split(".")) + instance = getattr(model_package, class_name) + return instance +