From 987e86a69fe801ad571e3be8005e958a3dc494b2 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 2 Apr 2020 02:46:04 +0000 Subject: [PATCH] fix eleps package --- examples/ctr-dnn_train.yaml | 13 ++++++++----- examples/train.py | 8 +++----- trainer/factory.py | 13 +++++++------ 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/examples/ctr-dnn_train.yaml b/examples/ctr-dnn_train.yaml index 2baf63f2..98a533cb 100644 --- a/examples/ctr-dnn_train.yaml +++ b/examples/ctr-dnn_train.yaml @@ -41,11 +41,14 @@ train: model: models: "eleps.models.ctr_dnn.model.py" hyper_parameters: - sparse_inputs_slots: 27, - sparse_feature_number: 1000001, - sparse_feature_dim: 8, - dense_input_dim: 13, - fc_sizes: [1024, 512, 32], + sparse_inputs_slots: 27 + sparse_feature_number: 1000001 + sparse_feature_dim: 8 + dense_input_dim: 13 + fc_sizes: [101, 512, 32] + # - 1024 + # - 512 + # - 32 learning_rate: 0.001 save: diff --git a/examples/train.py b/examples/train.py index c91b2f45..d4943d8d 100644 --- a/examples/train.py +++ b/examples/train.py @@ -31,12 +31,10 @@ from eleps.trainer.factory import TrainerFactory if __name__ == "__main__": - with open('ctr-dnn_train.yaml', 'r') as rb: - global_config = yaml.load(rb.read()) + abs_dir = os.path.dirname(os.path.abspath(__file__)) - print global_config - - os.exit() + with open(os.path.join(abs_dir, 'ctr-dnn_train.yaml'), 'r') as rb: + global_config = yaml.load(rb.read(), Loader=yaml.FullLoader) trainer = TrainerFactory.create(global_config) trainer.run() diff --git a/trainer/factory.py b/trainer/factory.py index b96d2443..393a2d44 100644 --- a/trainer/factory.py +++ b/trainer/factory.py @@ -27,15 +27,15 @@ import os import yaml -from .single_train import SingleTrainerWithDataloader -from .single_train import SingleTrainerWithDataset +from eleps.trainer.single_train import SingleTrainerWithDataloader +from eleps.trainer.single_train import SingleTrainerWithDataset -from .cluster_train import ClusterTrainerWithDataloader -from .cluster_train import ClusterTrainerWithDataset +from eleps.trainer.cluster_train import ClusterTrainerWithDataloader +from eleps.trainer.cluster_train import ClusterTrainerWithDataset -from .ctr_trainer import CtrPaddleTrainer +from eleps.trainer.ctr_trainer import CtrPaddleTrainer -from ..utils import envs +from eleps.utils import envs class TrainerFactory(object): @@ -83,3 +83,4 @@ class TrainerFactory(object): trainer = TrainerFactory._build_trainer(_config) return trainer + -- GitLab