diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 1d6800e1154a04845bdc5443e86a7d8da61c57cb..631bb1f6bfbc9603fd796cdba12f573050994f54 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -46,7 +46,7 @@ class TrainerFactory(object): if trainer_abs is None: if not os.path.isfile(train_mode): - raise FileNotFoundError("trainer {} can not be recognized".format(train_mode)) + raise IOError("trainer {} can not be recognized".format(train_mode)) trainer_abs = train_mode train_mode = "UserDefineTrainer" diff --git a/fleetrec/demo/ctr-dnn_train.yaml b/fleetrec/demo/ctr-dnn_train.yaml index 591d1fb3c99e59bf145f1bfbd63d5083396cc9c1..558fafcf631b42b9ddaf4dbc5cfafa6fe5fbcae4 100644 --- a/fleetrec/demo/ctr-dnn_train.yaml +++ b/fleetrec/demo/ctr-dnn_train.yaml @@ -14,7 +14,7 @@ train: trainer: - trainer: "fleetrec/examples/user_define_trainer.py" + trainer: "fleetrec/demo/user_define_trainer.py" threads: 4 # for cluster training strategy: "async" @@ -29,8 +29,8 @@ train: reader: mode: "dataset" batch_size: 2 - class: "fleetrec.models.ctr_dnn.reader" - train_data_path: "fleetrec/models/ctr_dnn/data/train/" + class: "fleetrec.models.ctr.criteo_reader" + train_data_path: "fleetrec/models/ctr/dnn/data/train" model: models: "fleetrec.models.ctr_dnn.model" diff --git a/fleetrec/run.py b/fleetrec/run.py index 9a4645a7616aed9fb56c801f6396a3dcee019224..0bdc49b83c74aa868d10d46ea397d9594f4a7e97 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -172,9 +172,10 @@ if __name__ == "__main__": args.device = args.device.upper() if not os.path.isfile(args.model): - raise FileNotFoundError("argument model: {} do not exist".format(args.model)) + raise IOError("argument model: {} do not exist".format(args.model)) engine_registry() which_engine = get_engine(args.engine, args.device) + engine = which_engine(args) engine.run()