提交 3f3feb83 编写于 作者: T tangwei12

bug fix

上级 52e6010e
...@@ -46,7 +46,7 @@ class TrainerFactory(object): ...@@ -46,7 +46,7 @@ class TrainerFactory(object):
if trainer_abs is None: if trainer_abs is None:
if not os.path.isfile(train_mode): if not os.path.isfile(train_mode):
raise FileNotFoundError("trainer {} can not be recognized".format(train_mode)) raise IOError("trainer {} can not be recognized".format(train_mode))
trainer_abs = train_mode trainer_abs = train_mode
train_mode = "UserDefineTrainer" train_mode = "UserDefineTrainer"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
train: train:
trainer: trainer:
trainer: "fleetrec/examples/user_define_trainer.py" trainer: "fleetrec/demo/user_define_trainer.py"
threads: 4 threads: 4
# for cluster training # for cluster training
strategy: "async" strategy: "async"
...@@ -29,8 +29,8 @@ train: ...@@ -29,8 +29,8 @@ train:
reader: reader:
mode: "dataset" mode: "dataset"
batch_size: 2 batch_size: 2
class: "fleetrec.models.ctr_dnn.reader" class: "fleetrec.models.ctr.criteo_reader"
train_data_path: "fleetrec/models/ctr_dnn/data/train/" train_data_path: "fleetrec/models/ctr/dnn/data/train"
model: model:
models: "fleetrec.models.ctr_dnn.model" models: "fleetrec.models.ctr_dnn.model"
......
...@@ -172,9 +172,10 @@ if __name__ == "__main__": ...@@ -172,9 +172,10 @@ if __name__ == "__main__":
args.device = args.device.upper() args.device = args.device.upper()
if not os.path.isfile(args.model): if not os.path.isfile(args.model):
raise FileNotFoundError("argument model: {} do not exist".format(args.model)) raise IOError("argument model: {} do not exist".format(args.model))
engine_registry() engine_registry()
which_engine = get_engine(args.engine, args.device) which_engine = get_engine(args.engine, args.device)
engine = which_engine(args) engine = which_engine(args)
engine.run() engine.run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册