diff --git a/fleetrec/run.py b/fleetrec/run.py index 641b2ce47c2c462e684b38154d817f2b3a7f237a..9a4645a7616aed9fb56c801f6396a3dcee019224 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -164,10 +164,12 @@ def local_mpi_engine(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description='fleet-rec run') parser.add_argument("-m", "--model", type=str) - parser.add_argument("-e", "--engine", type=str, choices=clusters) - parser.add_argument("-d", "--device", type=str, choices=["CPU", "GPU"], default="CPU") + parser.add_argument("-e", "--engine", type=str, choices=["single", "local_cluster", "cluster"]) + parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu") args = parser.parse_args() + args.engine = args.engine.upper() + args.device = args.device.upper() if not os.path.isfile(args.model): raise FileNotFoundError("argument model: {} do not exist".format(args.model)) diff --git a/readme.md b/readme.md index 5beb6ea63254ef9024dfe798f871bf7d1f3a3fde..1565b367a6046cc9d1b02ce78d92469fc444ee9c 100644 --- a/readme.md +++ b/readme.md @@ -28,17 +28,25 @@ cd FleetRec python -m fleetrec.run \ -m fleetrec/examples/ctr-dnn_train.yaml \ + -d cpu \ -e single + +# 使用GPU资源进行训练 +python -m fleetrec.run \ + -m fleetrec/examples/ctr-dnn_train.yaml \ + -d gpu \ + -e single ``` ### 本地模拟分布式训练 ```bash cd FleetRec - +# 使用CPU资源进行训练 python -m fleetrec.run \ -m fleetrec/examples/ctr-dnn_train.yaml \ - -e local_cluster + -d cpu \ + -e local_cluster ``` ### 集群提交分布式训练<需要用户预先配置好集群环境,本提交命令不包含提交客户端> @@ -48,6 +56,7 @@ cd FleetRec python -m fleetrec.run \ -m fleetrec/examples/ctr-dnn_train.yaml \ + -d cpu \ -e cluster ```