提交 45d20ee2 编写于 作者: T tangwei

add build-in config for model

上级 c6b1cf3b
......@@ -169,23 +169,34 @@ def local_mpi_engine(args):
return launch
def get_abs_model(model):
if model.startswith("fleetrec."):
fleet_base = envs.get_runtime_environ("PACKAGE_BASE")
workspace_dir = model.split("fleetrec.")[1].replace(".", "/")
path = os.path.join(fleet_base, workspace_dir, "config.yaml")
print("use built-in config: {} for model: {}".format(model, path))
else:
if not os.path.isfile(model):
raise IOError("model config: {} invalid".format(model))
path = model
return path
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-e", "--engine", type=str, choices=["single", "local_cluster", "cluster"])
parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu")
abs_dir = os.path.dirname(os.path.abspath(__file__))
envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})
args = parser.parse_args()
args.engine = args.engine.upper()
args.device = args.device.upper()
if not os.path.isfile(args.model):
raise IOError("argument model: {} do not exist".format(args.model))
args.model = get_abs_model(args.model)
engine_registry()
abs_dir = os.path.dirname(os.path.abspath(__file__))
envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})
which_engine = get_engine(args.engine, args.device)
engine = which_engine(args)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
train:
trainer:
# for cluster training
strategy: "async"
epochs: 10
workspace: "fleetrec.models.rank.dnn"
reader:
batch_size: 2
class: "{workspace}/../criteo_reader.py"
train_data_path: "{workspace}/data/train"
model:
models: "{workspace}/model.py"
hyper_parameters:
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 9
dense_input_dim: 13
fc_sizes: [512, 256, 128, 32]
learning_rate: 0.001
optimizer: adam
save:
increment:
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "inference"
epoch_interval: 4
save_last: True
......@@ -2,23 +2,14 @@
setup for fleet-rec.
"""
import os
import sys
from setuptools import setup, find_packages
import tempfile
import shutil
if sys.version_info.major == 2:
requires = [
"paddlepaddle == 1.7.2",
# "netron >= 0.0.0",
"pyyaml >= 5.1.1"
]
else:
requires = [
"paddlepaddle >= 0.0.0",
# "netron >= 0.0.0",
"pyyaml >= 5.1.1"
]
requires = [
"paddlepaddle == 1.7.2",
"pyyaml >= 5.1.1"
]
about = {}
about["__title__"] = "fleet-rec"
......@@ -37,7 +28,6 @@ def run_cmd(command):
def build(dirname):
package_dir = os.path.dirname(os.path.abspath(__file__))
run_cmd("cp -r {}/* {}".format(package_dir, dirname))
run_cmd("mkdir {}".format(os.path.join(dirname, "fleetrec")))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册