提交 987e86a6 编写于 作者: T tangwei12

fix eleps package

上级 8255fba7
...@@ -41,11 +41,14 @@ train: ...@@ -41,11 +41,14 @@ train:
model: model:
models: "eleps.models.ctr_dnn.model.py" models: "eleps.models.ctr_dnn.model.py"
hyper_parameters: hyper_parameters:
sparse_inputs_slots: 27, sparse_inputs_slots: 27
sparse_feature_number: 1000001, sparse_feature_number: 1000001
sparse_feature_dim: 8, sparse_feature_dim: 8
dense_input_dim: 13, dense_input_dim: 13
fc_sizes: [1024, 512, 32], fc_sizes: [101, 512, 32]
# - 1024
# - 512
# - 32
learning_rate: 0.001 learning_rate: 0.001
save: save:
......
...@@ -31,12 +31,10 @@ from eleps.trainer.factory import TrainerFactory ...@@ -31,12 +31,10 @@ from eleps.trainer.factory import TrainerFactory
if __name__ == "__main__": if __name__ == "__main__":
with open('ctr-dnn_train.yaml', 'r') as rb: abs_dir = os.path.dirname(os.path.abspath(__file__))
global_config = yaml.load(rb.read())
print global_config with open(os.path.join(abs_dir, 'ctr-dnn_train.yaml'), 'r') as rb:
global_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
os.exit()
trainer = TrainerFactory.create(global_config) trainer = TrainerFactory.create(global_config)
trainer.run() trainer.run()
...@@ -27,15 +27,15 @@ ...@@ -27,15 +27,15 @@
import os import os
import yaml import yaml
from .single_train import SingleTrainerWithDataloader from eleps.trainer.single_train import SingleTrainerWithDataloader
from .single_train import SingleTrainerWithDataset from eleps.trainer.single_train import SingleTrainerWithDataset
from .cluster_train import ClusterTrainerWithDataloader from eleps.trainer.cluster_train import ClusterTrainerWithDataloader
from .cluster_train import ClusterTrainerWithDataset from eleps.trainer.cluster_train import ClusterTrainerWithDataset
from .ctr_trainer import CtrPaddleTrainer from eleps.trainer.ctr_trainer import CtrPaddleTrainer
from ..utils import envs from eleps.utils import envs
class TrainerFactory(object): class TrainerFactory(object):
...@@ -83,3 +83,4 @@ class TrainerFactory(object): ...@@ -83,3 +83,4 @@ class TrainerFactory(object):
trainer = TrainerFactory._build_trainer(_config) trainer = TrainerFactory._build_trainer(_config)
return trainer return trainer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册