提交 da7c8cba 编写于 作者: Y yao_yf

wide_and_deep gpu host_device

上级 ddd91219
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# bash run_multigpu_train.sh RANK_SIZE EPOCH_SIZE DATASET
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
RANK_SIZE=$1
EPOCH_SIZE=$2
DATASET=$3
VOCAB_SIZE=$4
EMB_DIM=$5
mpirun --allow-run-as-root -n $RANK_SIZE \
python -s ${self_path}/../train_and_eval_auto_parallel.py \
--device_target="GPU" \
--data_path=$DATASET \
--epochs=$EPOCH_SIZE \
--vocab_size=$VOCAB_SIZE \
--emb_dim=$EMB_DIM \
--dropout_flag=1 \
--host_device_mix=1 > log.txt 2>&1 &
......@@ -18,6 +18,7 @@ import time
from mindspore.train.callback import Callback
from mindspore import context
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank
def add_write(file_path, out_str):
"""
......@@ -52,7 +53,14 @@ class LossCallBack(Callback):
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
rank_id = 0
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
ParallelMode.DATA_PARALLEL):
rank_id = get_rank()
print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch,
wide_loss, deep_loss, flush=True)
# raise ValueError
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
......@@ -99,13 +107,18 @@ class EvalCallBack(Callback):
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
rank_id = 0
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
ParallelMode.DATA_PARALLEL):
rank_id = get_rank()
start_time = time.time()
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
end_time = time.time()
eval_time = int(end_time - start_time)
time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime())
out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time)
out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\
format(time_str, rank_id, out.values(), eval_time)
print(out_str)
self.eval_values = out.values()
add_write(self.eval_file_name, out_str)
......@@ -201,6 +201,7 @@ class WideDeepModel(nn.Cell):
self.cast = P.Cast()
if is_auto_parallel and host_device_mix:
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
......
......@@ -32,13 +32,6 @@ from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
init()
def get_WideDeep_net(config):
"""
......@@ -131,6 +124,14 @@ def train_and_eval(config):
if __name__ == "__main__":
wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init()
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
if wide_deep_config.device_target == "Ascend":
init("hccl")
elif wide_deep_config.device_target == "GPU":
init("nccl")
if wide_deep_config.host_device_mix == 1:
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
else:
......
......@@ -16,6 +16,7 @@
import os
import sys
import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor
......@@ -68,6 +69,7 @@ def train_and_eval(config):
"""
train_and_eval
"""
np.random.seed(1000)
data_path = config.data_path
epochs = config.epochs
print("epochs is {}".format(epochs))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册