diff --git a/fleet_rec/run.py b/fleet_rec/run.py index bed88bae1ec80234adf2fdae6d9c5f7a3cf25292..41ee2c5996f451ba2a90b2b8681b0ee33b671333 100644 --- a/fleet_rec/run.py +++ b/fleet_rec/run.py @@ -169,23 +169,34 @@ def local_mpi_engine(args): return launch +def get_abs_model(model): + if model.startswith("fleetrec."): + fleet_base = envs.get_runtime_environ("PACKAGE_BASE") + workspace_dir = model.split("fleetrec.")[1].replace(".", "/") + path = os.path.join(fleet_base, workspace_dir, "config.yaml") + print("use built-in config: {} for model: {}".format(model, path)) + else: + if not os.path.isfile(model): + raise IOError("model config: {} invalid".format(model)) + path = model + return path + + if __name__ == "__main__": parser = argparse.ArgumentParser(description='fleet-rec run') parser.add_argument("-m", "--model", type=str) parser.add_argument("-e", "--engine", type=str, choices=["single", "local_cluster", "cluster"]) parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu") + abs_dir = os.path.dirname(os.path.abspath(__file__)) + envs.set_runtime_environs({"PACKAGE_BASE": abs_dir}) + args = parser.parse_args() args.engine = args.engine.upper() args.device = args.device.upper() - - if not os.path.isfile(args.model): - raise IOError("argument model: {} do not exist".format(args.model)) + args.model = get_abs_model(args.model) engine_registry() - abs_dir = os.path.dirname(os.path.abspath(__file__)) - envs.set_runtime_environs({"PACKAGE_BASE": abs_dir}) - which_engine = get_engine(args.engine, args.device) engine = which_engine(args) diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..66f5053a0f2364a9038e6b9b5f3ccaad6481d6ca --- /dev/null +++ b/models/rank/dnn/config.yaml @@ -0,0 +1,47 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +train: + trainer: + # for cluster training + strategy: "async" + + epochs: 10 + workspace: "fleetrec.models.rank.dnn" + + reader: + batch_size: 2 + class: "{workspace}/../criteo_reader.py" + train_data_path: "{workspace}/data/train" + + model: + models: "{workspace}/model.py" + hyper_parameters: + sparse_inputs_slots: 27 + sparse_feature_number: 1000001 + sparse_feature_dim: 9 + dense_input_dim: 13 + fc_sizes: [512, 256, 128, 32] + learning_rate: 0.001 + optimizer: adam + + save: + increment: + dirname: "increment" + epoch_interval: 2 + save_last: True + inference: + dirname: "inference" + epoch_interval: 4 + save_last: True diff --git a/setup.py b/setup.py index 84091b676a8e3ca288fc12772b98ad972bbe94aa..7ecc4a916d251773233bdb0f79e28a108c38e6cc 100644 --- a/setup.py +++ b/setup.py @@ -2,23 +2,14 @@ setup for fleet-rec. """ import os -import sys from setuptools import setup, find_packages import tempfile import shutil -if sys.version_info.major == 2: - requires = [ - "paddlepaddle == 1.7.2", - # "netron >= 0.0.0", - "pyyaml >= 5.1.1" - ] -else: - requires = [ - "paddlepaddle >= 0.0.0", - # "netron >= 0.0.0", - "pyyaml >= 5.1.1" - ] +requires = [ + "paddlepaddle == 1.7.2", + "pyyaml >= 5.1.1" +] about = {} about["__title__"] = "fleet-rec" @@ -37,7 +28,6 @@ def run_cmd(command): def build(dirname): - package_dir = os.path.dirname(os.path.abspath(__file__)) run_cmd("cp -r {}/* {}".format(package_dir, dirname)) run_cmd("mkdir {}".format(os.path.join(dirname, "fleetrec")))