diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 6bf8a176161e33bbec6430f628fdf71081031ecb..c2e02bf5e331b067639e58079bfae3091331f017 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -28,10 +28,7 @@ class TrainerFactory(object): def _build_trainer(yaml_path): print(envs.pretty_print_envs(envs.get_global_envs())) - train_mode = envs.get_global_env("train.trainer") - - if train_mode is None: - train_mode = envs.get_runtime_envion("train.trainer") + train_mode = envs.get_training_mode() if train_mode == "SingleTraining": from fleetrec.core.trainers.single_trainer import SingleTrainer diff --git a/fleetrec/core/trainers/ctr_coding_trainer.py b/fleetrec/core/trainers/ctr_coding_trainer.py index 7ba3bec71b260acd391ada55df708eb57c22c08a..34f8fa44e8eecb960bb01445397f0184c56cd3e5 100755 --- a/fleetrec/core/trainers/ctr_coding_trainer.py +++ b/fleetrec/core/trainers/ctr_coding_trainer.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import sys -import time -import json -import datetime import numpy as np import paddle.fluid as fluid diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py index 79383172ca6dea8bccddba0e27a9d792d3ae7876..fc5228f01e9caaac2daf494f2bfde417393570c0 100644 --- a/fleetrec/core/utils/envs.py +++ b/fleetrec/core/utils/envs.py @@ -29,6 +29,14 @@ def get_runtime_envion(key): return os.getenv(key, None) +def get_training_mode(): + train_mode = get_global_env("train.trainer") + + if train_mode is None: + train_mode = get_runtime_envion("train.trainer") + return train_mode + + def set_global_envs(envs): assert isinstance(envs, dict) diff --git a/fleetrec/examples/__init__.py b/fleetrec/examples/__init__.py index 4d1d1a6d8a36d3131eb9bd06d14d3faf4e9dd5e0..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/fleetrec/examples/__init__.py +++ b/fleetrec/examples/__init__.py @@ -1,18 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -examples for user -""" \ No newline at end of file diff --git a/fleetrec/examples/build_in/__init__.py b/fleetrec/examples/build_in/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/fleetrec/examples/build_in/cluster_training_mpi.yaml b/fleetrec/examples/build_in/cluster_training_mpi.yaml deleted file mode 100644 index 19974f40ad5256b65bbf3c6adcb1bf6217756bd1..0000000000000000000000000000000000000000 --- a/fleetrec/examples/build_in/cluster_training_mpi.yaml +++ /dev/null @@ -1,10 +0,0 @@ - -trainer: "MPIClusterTraining" - -pserver_num: 2 -trainer_num: 2 -start_port: 36001 -log_dirname: "logs" - -strategy: - mode: "async" diff --git a/fleetrec/examples/build_in/ctr-dnn_train.yaml b/fleetrec/examples/ctr-dnn_train.yaml similarity index 98% rename from fleetrec/examples/build_in/ctr-dnn_train.yaml rename to fleetrec/examples/ctr-dnn_train.yaml index 87de5fae0b41562f0efc0853767e39c0d0f07ae7..bbf63cd0506a1966c927e10e5dc7281f8b66de0b 100644 --- a/fleetrec/examples/build_in/ctr-dnn_train.yaml +++ b/fleetrec/examples/ctr-dnn_train.yaml @@ -27,7 +27,7 @@ train: hyper_parameters: sparse_inputs_slots: 27 sparse_feature_number: 1000001 - sparse_feature_dim: 8 + sparse_feature_dim: 9 dense_input_dim: 13 fc_sizes: [512, 256, 128, 32] learning_rate: 0.001 diff --git a/fleetrec/examples/user_define/__init__.py b/fleetrec/examples/user_define/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/fleetrec/examples/build_in/user_define_trainer.py b/fleetrec/examples/user_define/user_define_trainer.py similarity index 100% rename from fleetrec/examples/build_in/user_define_trainer.py rename to fleetrec/examples/user_define/user_define_trainer.py diff --git a/fleetrec/examples/build_in/cluster_training_user_define.yaml b/fleetrec/examples/user_define/user_define_trainer.yaml similarity index 100% rename from fleetrec/examples/build_in/cluster_training_user_define.yaml rename to fleetrec/examples/user_define/user_define_trainer.yaml diff --git a/fleetrec/models/ctr_dnn/model.py b/fleetrec/models/ctr_dnn/model.py index 970ef6facbf4d649fbd3dffa0fd9962820df34ce..a025b8a15237e7f401a8a7293d03f55497d4ea06 100644 --- a/fleetrec/models/ctr_dnn/model.py +++ b/fleetrec/models/ctr_dnn/model.py @@ -60,13 +60,18 @@ class Model(ModelBase): self._data_var.append(self.label_input) def net(self): - def embedding_layer(input): - sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self.namespace) - sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self.namespace) + train_mode = envs.get_training_mode() + + is_distributed = True if train_mode == "CtrTraining" else False + sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self.namespace) + sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self.namespace) + sparse_feature_dim = 9 if train_mode == "CtrTraining" else sparse_feature_dim + def embedding_layer(input): emb = fluid.layers.embedding( input=input, is_sparse=True, + is_distributed=is_distributed, size=[sparse_feature_number, sparse_feature_dim], param_attr=fluid.ParamAttr( name="SparseFeatFactors", diff --git a/fleetrec/run.py b/fleetrec/run.py index a3cdc2175be597924a48ed027beea153952b7b57..3ea52c7fadd429a818ecbb092f9fffa1c8d6d3ef 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -103,6 +103,15 @@ if __name__ == "__main__": local_mpi_engine(cluster_envs, args.model) elif args.engine.upper() == "CLUSTER": print("launch ClusterTraining with cluster to run model: {}".format(args.model)) + + if version.is_transpiler(): + print("use ClusterTraining to run model: {}".format(args.model)) + cluster_envs = {"train.trainer": "ClusterTraining"} + envs.set_runtime_envions(cluster_envs) + else: + cluster_envs = {"train.trainer": "CtrTraining"} + envs.set_runtime_envions(cluster_envs) + run(args.model) elif args.engine.upper() == "USER_DEFINE": engine_file = args.engine_extras