From 183ae5d00954d482dac1f34c28aae7e9b827c2b5 Mon Sep 17 00:00:00 2001 From: tom__chen Date: Tue, 18 Aug 2020 08:29:38 -0800 Subject: [PATCH] add gpu support for deepfm model fixed pylint errors --- model_zoo/official/recommend/deepfm/eval.py | 9 +++-- .../scripts/run_distribute_train_gpu.sh | 38 ++++++++++++++++++ .../recommend/deepfm/scripts/run_eval.sh | 12 +++--- .../deepfm/scripts/run_standalone_train.sh | 8 ++-- model_zoo/official/recommend/deepfm/train.py | 40 ++++++++++++++----- 5 files changed, 86 insertions(+), 21 deletions(-) create mode 100644 model_zoo/official/recommend/deepfm/scripts/run_distribute_train_gpu.sh diff --git a/model_zoo/official/recommend/deepfm/eval.py b/model_zoo/official/recommend/deepfm/eval.py index 0452f73d2..7ade4b618 100644 --- a/model_zoo/official/recommend/deepfm/eval.py +++ b/model_zoo/official/recommend/deepfm/eval.py @@ -24,16 +24,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.deepfm import ModelBuilder, AUCMetric from src.config import DataConfig, ModelConfig, TrainConfig -from src.dataset import create_dataset +from src.dataset import create_dataset, DataType sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) parser = argparse.ArgumentParser(description='CTR Prediction') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') - +parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU') args_opt, _ = parser.parse_known_args() device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) def add_write(file_path, print_str): @@ -47,7 +47,8 @@ if __name__ == '__main__': train_config = TrainConfig() ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, - epochs=1, batch_size=train_config.batch_size) + epochs=1, batch_size=train_config.batch_size, + data_type=DataType(data_config.data_format)) model_builder = ModelBuilder(ModelConfig, TrainConfig) train_net, eval_net = model_builder.get_train_eval_net() train_net.set_train() diff --git a/model_zoo/official/recommend/deepfm/scripts/run_distribute_train_gpu.sh b/model_zoo/official/recommend/deepfm/scripts/run_distribute_train_gpu.sh new file mode 100644 index 000000000..832cc409d --- /dev/null +++ b/model_zoo/official/recommend/deepfm/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "Please run the script as: " +echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH" +echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path" +echo "After running the script, the network runs in the background, The log will be generated in log/output.log" + + +export RANK_SIZE=$1 +DATA_URL=$2 + +rm -rf log +mkdir ./log +cp *.py ./log +cp -r src ./log +cd ./log || exit +env > env.log +mpirun --allow-run-as-root -n $RANK_SIZE \ + python -u train.py \ + --dataset_path=$DATA_URL \ + --ckpt_path="checkpoint" \ + --eval_file_name='auc.log' \ + --loss_file_name='loss.log' \ + --device_target='GPU' \ + --do_eval=True > output.log 2>&1 & diff --git a/model_zoo/official/recommend/deepfm/scripts/run_eval.sh b/model_zoo/official/recommend/deepfm/scripts/run_eval.sh index aa5765da3..5994756a4 100644 --- a/model_zoo/official/recommend/deepfm/scripts/run_eval.sh +++ b/model_zoo/official/recommend/deepfm/scripts/run_eval.sh @@ -14,13 +14,14 @@ # limitations under the License. # ============================================================================ echo "Please run the script as: " -echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH" -echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path" +echo "sh scripts/run_eval.sh DEVICE_ID DEVICE_TARGET DATASET_PATH CHECKPOINT_PATH" +echo "for example: sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path" echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log" export DEVICE_ID=$1 -DATA_URL=$2 -CHECKPOINT_PATH=$3 +DEVICE_TARGET=$2 +DATA_URL=$3 +CHECKPOINT_PATH=$4 mkdir -p ms_log CUR_DIR=`pwd` @@ -29,4 +30,5 @@ export GLOG_logtostderr=0 python -u eval.py \ --dataset_path=$DATA_URL \ - --checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 & \ No newline at end of file + --checkpoint_path=$CHECKPOINT_PATH \ + --device_target=$DEVICE_TARGET > ms_log/eval_output.log 2>&1 & diff --git a/model_zoo/official/recommend/deepfm/scripts/run_standalone_train.sh b/model_zoo/official/recommend/deepfm/scripts/run_standalone_train.sh index fa22b82d3..f2c5c16ee 100644 --- a/model_zoo/official/recommend/deepfm/scripts/run_standalone_train.sh +++ b/model_zoo/official/recommend/deepfm/scripts/run_standalone_train.sh @@ -14,12 +14,13 @@ # limitations under the License. # ============================================================================ echo "Please run the script as: " -echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH" -echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path" +echo "sh scripts/run_standalone_train.sh DEVICE_ID DEVICE_TARGET DATASET_PATH" +echo "for example: sh scripts/run_standalone_train.sh 0 GPU /dataset_path" echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log" export DEVICE_ID=$1 -DATA_URL=$2 +DEVICE_TARGET=$2 +DATA_URL=$3 mkdir -p ms_log CUR_DIR=`pwd` @@ -31,4 +32,5 @@ python -u train.py \ --ckpt_path="checkpoint" \ --eval_file_name='auc.log' \ --loss_file_name='loss.log' \ + --device_target=$DEVICE_TARGET \ --do_eval=True > ms_log/output.log 2>&1 & diff --git a/model_zoo/official/recommend/deepfm/train.py b/model_zoo/official/recommend/deepfm/train.py index ff110cd5a..656a7bfa0 100644 --- a/model_zoo/official/recommend/deepfm/train.py +++ b/model_zoo/official/recommend/deepfm/train.py @@ -16,11 +16,14 @@ import os import sys import argparse +import random +import numpy as np from mindspore import context, ParallelMode -from mindspore.communication.management import init +from mindspore.communication.management import init, get_rank, get_group_size from mindspore.train.model import Model from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +import mindspore.dataset.engine as de from src.deepfm import ModelBuilder, AUCMetric from src.config import DataConfig, ModelConfig, TrainConfig @@ -34,24 +37,41 @@ parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path') parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path') parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.') - +parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU') args_opt, _ = parser.parse_known_args() -device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) +rank_size = int(os.environ.get("RANK_SIZE", 1)) +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) if __name__ == '__main__': data_config = DataConfig() model_config = ModelConfig() train_config = TrainConfig() - rank_size = int(os.environ.get("RANK_SIZE", 1)) if rank_size > 1: - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) - init() - rank_id = int(os.environ.get('RANK_ID')) + if args_opt.device_target == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) + init() + rank_id = int(os.environ.get('RANK_ID')) + elif args_opt.device_target == "GPU": + init("nccl") + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=get_group_size(), + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + rank_id = get_rank() + else: + print("Unsupported device_target ", args_opt.device_target) + exit() else: + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) rank_size = None rank_id = None @@ -73,6 +93,8 @@ if __name__ == '__main__': callback_list = [time_callback, loss_callback] if train_config.save_checkpoint: + if rank_size: + train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank()) config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps, keep_checkpoint_max=train_config.keep_checkpoint_max) ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix, -- GitLab