diff --git a/examples/ctr-dnn_train.yaml b/examples/ctr-dnn_train.yaml index 2baf63f2ac9a609a1ac84ea9b366dc184674f2b0..98a533cb299dfdfc7fadac587f9db09ba11508c4 100644 --- a/examples/ctr-dnn_train.yaml +++ b/examples/ctr-dnn_train.yaml @@ -41,11 +41,14 @@ train: model: models: "eleps.models.ctr_dnn.model.py" hyper_parameters: - sparse_inputs_slots: 27, - sparse_feature_number: 1000001, - sparse_feature_dim: 8, - dense_input_dim: 13, - fc_sizes: [1024, 512, 32], + sparse_inputs_slots: 27 + sparse_feature_number: 1000001 + sparse_feature_dim: 8 + dense_input_dim: 13 + fc_sizes: [101, 512, 32] + # - 1024 + # - 512 + # - 32 learning_rate: 0.001 save: diff --git a/examples/train.py b/examples/train.py index c91b2f45c73ed79e7c2ec181db05eaf50e819005..d4943d8d2251a56ab852d0040bcc71821277aa76 100644 --- a/examples/train.py +++ b/examples/train.py @@ -31,12 +31,10 @@ from eleps.trainer.factory import TrainerFactory if __name__ == "__main__": - with open('ctr-dnn_train.yaml', 'r') as rb: - global_config = yaml.load(rb.read()) + abs_dir = os.path.dirname(os.path.abspath(__file__)) - print global_config - - os.exit() + with open(os.path.join(abs_dir, 'ctr-dnn_train.yaml'), 'r') as rb: + global_config = yaml.load(rb.read(), Loader=yaml.FullLoader) trainer = TrainerFactory.create(global_config) trainer.run() diff --git a/trainer/factory.py b/trainer/factory.py index b96d244365bf6ff4e2703ccf6b73fb572dfe986a..393a2d44b6b1ebb4495a857c06d1b87e59384dc6 100644 --- a/trainer/factory.py +++ b/trainer/factory.py @@ -27,15 +27,15 @@ import os import yaml -from .single_train import SingleTrainerWithDataloader -from .single_train import SingleTrainerWithDataset +from eleps.trainer.single_train import SingleTrainerWithDataloader +from eleps.trainer.single_train import SingleTrainerWithDataset -from .cluster_train import ClusterTrainerWithDataloader -from .cluster_train import ClusterTrainerWithDataset +from eleps.trainer.cluster_train import ClusterTrainerWithDataloader +from eleps.trainer.cluster_train import ClusterTrainerWithDataset -from .ctr_trainer import CtrPaddleTrainer +from eleps.trainer.ctr_trainer import CtrPaddleTrainer -from ..utils import envs +from eleps.utils import envs class TrainerFactory(object): @@ -83,3 +83,4 @@ class TrainerFactory(object): trainer = TrainerFactory._build_trainer(_config) return trainer +