Commit 183ae5d0 authored by tom__chen

add gpu support for deepfm model

fixed pylint errors
Parent 0e27a04d
@@ -24,16 +24,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
-from src.dataset import create_dataset
+from src.dataset import create_dataset, DataType
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU')
args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
def add_write(file_path, print_str):
@@ -47,7 +47,8 @@ if __name__ == '__main__':
    train_config = TrainConfig()
    ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
-                             epochs=1, batch_size=train_config.batch_size)
+                             epochs=1, batch_size=train_config.batch_size,
+                             data_type=DataType(data_config.data_format))
    model_builder = ModelBuilder(ModelConfig, TrainConfig)
    train_net, eval_net = model_builder.get_train_eval_net()
    train_net.set_train()
...
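For context: `DataType(data_config.data_format)` in the hunk above turns the integer `data_format` field of the data config into the enum value that `create_dataset` expects. A minimal sketch of that relationship, assuming `DataType` is a simple enum in `src/dataset.py` (member names and values here are illustrative, not taken from this commit):

```python
from enum import IntEnum

class DataType(IntEnum):
    """Illustrative stand-in for src.dataset.DataType (names/values assumed)."""
    MINDRECORD = 1
    TFRECORD = 2
    H5 = 3

# data_config.data_format is an integer in src.config, e.g. 1,
# so DataType(data_config.data_format) resolves it to the matching member.
data_format = 1
print(DataType(data_format))  # DataType.MINDRECORD
```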
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH"
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path"
echo "After running the script, the network runs in the background, The log will be generated in log/output.log"
export RANK_SIZE=$1
DATA_URL=$2
rm -rf log
mkdir ./log
cp *.py ./log
cp -r src ./log
cd ./log || exit
env > env.log
mpirun --allow-run-as-root -n $RANK_SIZE \
python -u train.py \
--dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='GPU' \
--do_eval=True > output.log 2>&1 &
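The mpirun-based script above exports RANK_SIZE, copies the sources into ./log, and launches one train.py process per GPU. A condensed sketch of how each launched process then sets itself up on the Python side; this mirrors the GPU branch of the train.py hunk further below and adds nothing beyond it:

```python
# Per-process setup when train.py is launched by mpirun with --device_target='GPU'
# (mirrors the GPU branch in the train.py diff below).
from mindspore import context, ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

init("nccl")                                  # join the NCCL communicator set up by mpirun
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=get_group_size(),   # e.g. 8 when RANK_SIZE=8
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  mirror_mean=True)
rank_id = get_rank()                          # this worker's index, 0..RANK_SIZE-1
```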
@@ -14,13 +14,14 @@
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
-echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH"
-echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path"
+echo "sh scripts/run_eval.sh DEVICE_ID DEVICE_TARGET DATASET_PATH CHECKPOINT_PATH"
+echo "for example: sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path"
echo "After running the script, the network runs in the background. The log will be generated in ms_log/eval_output.log"
export DEVICE_ID=$1
-DATA_URL=$2
-CHECKPOINT_PATH=$3
+DEVICE_TARGET=$2
+DATA_URL=$3
+CHECKPOINT_PATH=$4
mkdir -p ms_log
CUR_DIR=`pwd`
@@ -29,4 +30,5 @@ export GLOG_logtostderr=0
python -u eval.py \
--dataset_path=$DATA_URL \
---checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 &
\ No newline at end of file
+--checkpoint_path=$CHECKPOINT_PATH \
+--device_target=$DEVICE_TARGET > ms_log/eval_output.log 2>&1 &
@@ -14,12 +14,13 @@
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
-echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH"
-echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path"
+echo "sh scripts/run_standalone_train.sh DEVICE_ID DEVICE_TARGET DATASET_PATH"
+echo "for example: sh scripts/run_standalone_train.sh 0 GPU /dataset_path"
echo "After running the script, the network runs in the background. The log will be generated in ms_log/output.log"
export DEVICE_ID=$1
-DATA_URL=$2
+DEVICE_TARGET=$2
+DATA_URL=$3
mkdir -p ms_log
CUR_DIR=`pwd`
@@ -31,4 +32,5 @@ python -u train.py \
--ckpt_path="checkpoint" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
+--device_target=$DEVICE_TARGET \
--do_eval=True > ms_log/output.log 2>&1 &
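Both single-device scripts still export DEVICE_ID. On the Python side, train.py reads it back in its single-device branch (the `else:` in the hunk below) and hands it to context.set_context together with the new --device_target value. A self-contained sketch of that branch; the device_target value and the DEVICE_ID fallback are hard-coded here purely for illustration:

```python
# Single-device setup: DEVICE_ID comes from the shell script, the target from --device_target.
import os
from mindspore import context

device_target = "GPU"                        # stand-in for args_opt.device_target
device_id = int(os.getenv('DEVICE_ID', '0'))  # the scripts export the real value; 0 is a sketch default
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, device_id=device_id)
```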
@@ -16,11 +16,14 @@
import os
import sys
import argparse
+import random
+import numpy as np
from mindspore import context, ParallelMode
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
+import mindspore.dataset.engine as de
from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
@@ -34,24 +37,41 @@ parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path')
parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path')
parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.')
+parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU')
args_opt, _ = parser.parse_known_args()
-device_id = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
+rank_size = int(os.environ.get("RANK_SIZE", 1))
+random.seed(1)
+np.random.seed(1)
+de.config.set_seed(1)
if __name__ == '__main__':
    data_config = DataConfig()
    model_config = ModelConfig()
    train_config = TrainConfig()
-    rank_size = int(os.environ.get("RANK_SIZE", 1))
    if rank_size > 1:
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
-        init()
-        rank_id = int(os.environ.get('RANK_ID'))
+        if args_opt.device_target == "Ascend":
+            device_id = int(os.getenv('DEVICE_ID'))
+            context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
+            init()
+            rank_id = int(os.environ.get('RANK_ID'))
+        elif args_opt.device_target == "GPU":
+            init("nccl")
+            context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=get_group_size(),
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              mirror_mean=True)
+            rank_id = get_rank()
+        else:
+            print("Unsupported device_target ", args_opt.device_target)
+            exit()
    else:
+        device_id = int(os.getenv('DEVICE_ID'))
+        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
        rank_size = None
        rank_id = None
@@ -73,6 +93,8 @@ if __name__ == '__main__':
    callback_list = [time_callback, loss_callback]
    if train_config.save_checkpoint:
+        if rank_size:
+            train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
        config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
                                     keep_checkpoint_max=train_config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
...
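The last hunk keeps distributed workers from overwriting each other's checkpoints: whenever rank_size is set, each process appends its rank to the checkpoint prefix before ModelCheckpoint is built. A tiny sketch of the effect, assuming the configured prefix is "deepfm" (the real value lives in TrainConfig and is not shown in this diff):

```python
# Effect of the per-rank checkpoint prefix (prefix value assumed for illustration).
ckpt_file_name_prefix = "deepfm"   # stand-in for train_config.ckpt_file_name_prefix
rank_size = 8                      # set when launched via mpirun / RANK_SIZE
rank = 3                           # what get_rank() would return on this worker

if rank_size:
    ckpt_file_name_prefix = ckpt_file_name_prefix + str(rank)

print(ckpt_file_name_prefix)  # "deepfm3", so each worker writes its own checkpoint series
```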