提交 a1347264 编写于 作者: L linqingke

mass add gpu support.

上级 5b7875ba
...@@ -57,9 +57,6 @@ The overall network architecture of MASS is shown below, which is Transformer(Va ...@@ -57,9 +57,6 @@ The overall network architecture of MASS is shown below, which is Transformer(Va
MASS is consisted of 6-layer encoder and 6-layer decoder with 1024 embedding/hidden size, and 4096 intermediate size between feed forward network which has two full connection layers. MASS is consisted of 6-layer encoder and 6-layer decoder with 1024 embedding/hidden size, and 4096 intermediate size between feed forward network which has two full connection layers.
![Transformer architecture](https://cdn.analyticsvidhya.com/wp-content/uploads/2019/06/Screenshot-from-2019-06-17-19-53-10.png)
# Dataset # Dataset
Dataset used: Dataset used:
...@@ -124,7 +121,8 @@ MASS script and code structure are as follow: ...@@ -124,7 +121,8 @@ MASS script and code structure are as follow:
│ ├──all.bpe.codes // BPE codes table(this file should be generated by user). │ ├──all.bpe.codes // BPE codes table(this file should be generated by user).
│ ├──all_en.dict.bin // Learned vocabulary file(this file should be generated by user). │ ├──all_en.dict.bin // Learned vocabulary file(this file should be generated by user).
├── scripts ├── scripts
│ ├──run.sh // Train & evaluate model script. │ ├──run_ascend.sh // Ascend train & evaluate model script.
│ ├──run_gpu.sh // GPU train & evaluate model script.
│ ├──learn_subword.sh // Learn BPE codes. │ ├──learn_subword.sh // Learn BPE codes.
│ ├──stop_training.sh // Stop training. │ ├──stop_training.sh // Stop training.
├── requirements.txt // Requirements of third party package. ├── requirements.txt // Requirements of third party package.
...@@ -329,18 +327,24 @@ Almost all of the options and arguments needed could be assigned conveniently, i ...@@ -329,18 +327,24 @@ Almost all of the options and arguments needed could be assigned conveniently, i
For more detailed information about the attributes, refer to the file `config/config.py`. For more detailed information about the attributes, refer to the file `config/config.py`.
## Training & Evaluation process ## Training & Evaluation process
For training a model, the shell script `run.sh` is all you need. In this scripts, the environment variable is set and the training script `train.py` under `mass` is executed. For training a model, the shell script `run_ascend.sh` or `run_gpu.sh` is all you need. In this scripts, the environment variable is set and the training script `train.py` under `mass` is executed.
You may start a task training with single device or multiple devices by assigning the options and run the command in bash: You may start a task training with single device or multiple devices by assigning the options and run the command in bash:
```bash
sh run.sh [--options] Ascend:
```ascend
sh run_ascend.sh [--options]
```
GPU:
```gpu
sh run_gpu.sh [--options]
``` ```
The usage is shown as bellow: The usage of `run_ascend.sh` is shown as bellow:
```text ```text
Usage: run.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>] Usage: run_ascend.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>]
[-i, --device_id <N>] [-j, --hccl_json <FILE>] [-i, --device_id <N>] [-j, --hccl_json <FILE>]
[-c, --config <FILE>] [-o, --output <FILE>] [-c, --config <FILE>] [-o, --output <FILE>]
[-v, --vocab <FILE>] [-v, --vocab <FILE>]
options: options:
-h, --help show usage -h, --help show usage
...@@ -350,20 +354,49 @@ options: ...@@ -350,20 +354,49 @@ options:
-j, --hccl_json rank table file used for training with multiple devices: FILE. -j, --hccl_json rank table file used for training with multiple devices: FILE.
-c, --config configuration file as shown in the path 'mass/config': FILE. -c, --config configuration file as shown in the path 'mass/config': FILE.
-o, --output assign output file of inference: FILE. -o, --output assign output file of inference: FILE.
-v, --vocab set the vocabulary" -v, --vocab set the vocabulary.
-m, --metric set the metric.
``` ```
Notes: Be sure to assign the hccl_json file while running a distributed-training. Notes: Be sure to assign the hccl_json file while running a distributed-training.
The usage of `run_gpu.sh` is shown as bellow:
```text
Usage: run_gpu.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>]
[-i, --device_id <N>] [-c, --config <FILE>]
[-o, --output <FILE>] [-v, --vocab <FILE>]
options:
-h, --help show usage
-t, --task select task: CHAR, 't' for train and 'i' for inference".
-n, --device_num device number used for training: N, default is 1.
-i, --device_id device id used for training with single device: N, 0<=N<=7, default is 0.
-c, --config configuration file as shown in the path 'mass/config': FILE.
-o, --output assign output file of inference: FILE.
-v, --vocab set the vocabulary.
-m, --metric set the metric.
```
The command followed shows a example for training with 2 devices. The command followed shows a example for training with 2 devices.
```bash Ascend:
sh run.sh --task t --device_num 2 --hccl_json /{path}/rank_table.json --config /{path}/config.json ```ascend
sh run_ascend.sh --task t --device_num 2 --hccl_json /{path}/rank_table.json --config /{path}/config.json
``` ```
ps. Discontinuous device id is not supported in `run.sh` at present, device id in `rank_table.json` must start from 0. ps. Discontinuous device id is not supported in `run_ascend.sh` at present, device id in `rank_table.json` must start from 0.
GPU:
```gpu
sh run_gpu.sh --task t --device_num 2 --config /{path}/config.json
```
If use a single chip, it would be like this: If use a single chip, it would be like this:
```bash Ascend:
sh run.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json ```ascend
sh run_ascend.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json
```
GPU:
```gpu
sh run_gpu.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json
``` ```
...@@ -441,9 +474,6 @@ During testing, we use the fine-turned model to predict the result, and adopt a ...@@ -441,9 +474,6 @@ During testing, we use the fine-turned model to predict the result, and adopt a
get the most possible prediction results. get the most possible prediction results.
![MASS framework](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-Fig-2.png)
## Performance ## Performance
### Results ### Results
...@@ -543,11 +573,18 @@ For pre-training a model, config the options in `config.json` firstly: ...@@ -543,11 +573,18 @@ For pre-training a model, config the options in `config.json` firstly:
- Set other arguments including dataset configurations and network configurations. - Set other arguments including dataset configurations and network configurations.
- If you have a trained model already, assign the `existed_ckpt` to the checkpoint file. - If you have a trained model already, assign the `existed_ckpt` to the checkpoint file.
Run the shell script `run.sh` as followed: If you use the ascend chip, run the shell script `run_ascend.sh` as followed:
```bash ```ascend
sh run.sh -t t -n 1 -i 1 -c /mass/config/config.json sh run_ascend.sh -t t -n 1 -i 1 -c /mass/config/config.json
``` ```
You can also run the shell script `run_gpu.sh` on gpu as followed:
```gpu
sh run_gpu.sh -t t -n 1 -i 1 -c /mass/config/config.json
```
Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file. Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file.
## Fine-tuning ## Fine-tuning
...@@ -558,10 +595,18 @@ For fine-tuning a model, config the options in `config.json` firstly: ...@@ -558,10 +595,18 @@ For fine-tuning a model, config the options in `config.json` firstly:
- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files. - Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files.
- Set other arguments including dataset configurations and network configurations. - Set other arguments including dataset configurations and network configurations.
Run the shell script `run.sh` as followed: If you use the ascend chip, run the shell script `run_ascend.sh` as followed:
```bash
sh run.sh -t t -n 1 -i 1 -c config/config.json ```ascend
sh run_ascend.sh -t t -n 1 -i 1 -c config/config.json
```
You can also run the shell script `run_gpu.sh` on gpu as followed:
```gpu
sh run_gpu.sh -t t -n 1 -i 1 -c config/config.json
``` ```
Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file. Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file.
## Inference ## Inference
...@@ -573,10 +618,16 @@ For inference, config the options in `config.json` firstly: ...@@ -573,10 +618,16 @@ For inference, config the options in `config.json` firstly:
- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files. - Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files.
- Set other arguments including dataset configurations and network configurations. - Set other arguments including dataset configurations and network configurations.
Run the shell script `run.sh` as followed: If you use the ascend chip, run the shell script `run_ascend.sh` as followed:
```bash ```bash
sh run.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile} sh run_ascend.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile}
```
You can also run the shell script `run_gpu.sh` on gpu as followed:
```gpu
sh run_gpu.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile}
``` ```
# Description of random situation # Description of random situation
......
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Evaluation api.""" """Evaluation api."""
import os
import argparse import argparse
import pickle import pickle
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore import context
from config import TransformerConfig from config import TransformerConfig
from src.transformer import infer, infer_ppl from src.transformer import infer, infer_ppl
...@@ -32,6 +34,8 @@ parser.add_argument("--output", type=str, required=True, ...@@ -32,6 +34,8 @@ parser.add_argument("--output", type=str, required=True,
help="Result file path.") help="Result file path.")
parser.add_argument("--metric", type=str, default='rouge', parser.add_argument("--metric", type=str, default='rouge',
help='Set eval method.') help='Set eval method.')
parser.add_argument("--platform", type=str, required=True,
help="model working platform.")
def get_config(config): def get_config(config):
...@@ -46,6 +50,16 @@ if __name__ == '__main__': ...@@ -46,6 +50,16 @@ if __name__ == '__main__':
vocab = Dictionary.load_from_persisted_dict(args.vocab) vocab = Dictionary.load_from_persisted_dict(args.vocab)
_config = get_config(args.config) _config = get_config(args.config)
device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
device_id = 0
device_id = int(device_id)
context.set_context(
mode=context.GRAPH_MODE,
device_target=args.platform,
reserve_class_name_in_scope=False,
device_id=device_id)
if args.metric == 'rouge': if args.metric == 'rouge':
result = infer(_config) result = infer(_config)
else: else:
......
...@@ -165,10 +165,10 @@ do ...@@ -165,10 +165,10 @@ do
echo $task echo $task
if [ "$task" == "train" ] if [ "$task" == "train" ]
then then
python train.py --config ${configurations##*/} >>log.log 2>&1 & python train.py --config ${configurations##*/} --platform Ascend >>log.log 2>&1 &
elif [ "$task" == "infer" ] elif [ "$task" == "infer" ]
then then
python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 & python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform Ascend >>log_infer.log 2>&1 &
fi fi
cd ../ cd ../
done done
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
options=`getopt -u -o ht:n:i::o:v:m: -l help,task:,device_num:,device_id:,config:,output:,vocab:,metric: -- "$@"`
eval set -- "$options"
echo $options
echo_help()
{
echo "Usage:"
echo "bash train.sh [-h] [-t t|i] [-n N] [-i N] [-j FILE] [-c FILE] [-o FILE] [-v FILE]"
echo "options:"
echo " -h --help show usage"
echo " -t --task select task, 't' for training and 'i' for inference"
echo " -n --device_num training with N devices"
echo " -i --device_id training with device i"
echo " -c --config set the configuration file"
echo " -o --output set the output file of inference"
echo " -v --vocab set the vocabulary"
echo " -m --metric set the metric"
}
set_device_id()
{
while [ -n "$1" ]
do
if [[ "$1" == "-i" || "$1" == "--device_id" ]]
then
if [[ $2 -ge 0 && $2 -le 7 ]]
then
export DEVICE_ID=$2
fi
break
fi
shift
done
}
while [ -n "$1" ]
do
case "$1" in
-h|--help)
echo_help
shift
;;
-t|--task)
echo "task:"
if [ "$2" == "t" ]
then
task=train
elif [ "$2" == "i" ]
then
task=infer
fi
shift 2
;;
-n|--device_num)
echo "device_num"
if [ $2 -eq 1 ]
then
set_device_id $options
elif [ $2 -gt 1 ]
then
export RANK_SIZE=$2
fi
shift 2
;;
-i|--device_id)
echo "set device id"
export DEVICE_ID=$2
shift 2
;;
-c|--config)
echo "config";
configurations=$2
shift 2
;;
-o|--output)
echo "output";
output=$2
shift 2
;;
-v|--vocab)
echo "vocab";
vocab=$2
shift 2
;;
-m|--metric)
echo "metric";
metric=$2
shift 2
;;
--)
shift
break
;;
*)
shift
;;
esac
done
file_path=$(cd "$(dirname $0)" || exit; pwd)
if [ $RANK_SIZE -gt 1 ]
then
echo "Working on $RANK_SIZE device"
fi
echo "Working on file ${task}_mass_$DEVICE_ID"
cd $file_path || exit
cd ../ || exit
rm -rf ./${task}_mass_$DEVICE_ID
mkdir ./${task}_mass_$DEVICE_ID
cp train.py ./${task}_mass_$DEVICE_ID
cp eval.py ./${task}_mass_$DEVICE_ID
cp $configurations ./${task}_mass_$DEVICE_ID
if [ $vocab ]
then
cp $vocab ./${task}_mass_$DEVICE_ID
fi
cd ./${task}_mass_$DEVICE_ID || exit
env > log.log
echo $task
if [ "$task" == "train" ]
then
if [ $RANK_SIZE -gt 1 ]
then
mpirun -n $RANK_SIZE python train.py --config ${configurations##*/} --platform GPU >>log.log 2>&1 &
fi
python train.py --config ${configurations##*/} --platform GPU >>log.log 2>&1 &
elif [ "$task" == "infer" ]
then
python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform GPU >>log_infer.log 2>&1 &
fi
cd ../
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Transformer for training.""" """Transformer for training."""
from mindspore import nn from mindspore import nn
import mindspore.context as context
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import composite as C from mindspore.ops import composite as C
...@@ -204,11 +205,16 @@ class TransformerNetworkWithLoss(nn.Cell): ...@@ -204,11 +205,16 @@ class TransformerNetworkWithLoss(nn.Cell):
grad_scale = C.MultitypeFuncGraph("grad_scale") grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal() reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor") @grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad): def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad)) return grad * F.cast(reciprocal(scale), F.dtype(grad))
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class TransformerTrainOneStepWithLossScaleCell(nn.Cell): class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
""" """
...@@ -251,9 +257,16 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -251,9 +257,16 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients() self.clip_gradients = ClipGradients()
self.cast = P.Cast() self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus() if context.get_context("device_target") == "GPU":
self.get_status = P.NPUGetFloatStatus() self.gpu_target = True
self.clear_before_grad = P.NPUClearFloatStatus() self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.gpu_target = False
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_status = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False) self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32) self.base = Tensor(1, mstype.float32)
...@@ -304,14 +317,18 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -304,14 +317,18 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
target_mask, target_mask,
label_ids, label_ids,
label_weights) label_weights)
# Alloc status.
init = self.alloc_status() init = False
# Clear overflow buffer. if not self.gpu_target:
self.clear_before_grad(init) # init overflow buffer
init = self.alloc_status()
# clear overflow buffer
self.clear_status(init)
if sens is None: if sens is None:
scaling_sens = self.loss_scale scaling_sens = self.loss_scale
else: else:
scaling_sens = sens scaling_sens = sens
grads = self.grad(self.network, weights)(source_ids, grads = self.grad(self.network, weights)(source_ids,
source_mask, source_mask,
target_ids, target_ids,
...@@ -323,11 +340,21 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -323,11 +340,21 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag: if self.reducer_flag:
# Apply grad reducer on grads. # Apply grad reducer on grads.
grads = self.grad_reducer(grads) grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,)) # get the overflow buffer
if not self.gpu_target:
self.get_status(init)
# sum overflow buffer elements, 0:not overflow , >0:overflow
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
# convert flag_sum to scalar
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed: if self.is_distributed:
# Sum overflow flag over devices. # Sum overflow flag over devices.
......
...@@ -49,11 +49,13 @@ class LossCallBack(Callback): ...@@ -49,11 +49,13 @@ class LossCallBack(Callback):
file_name = "./loss.log" file_name = "./loss.log"
with open(file_name, "a+") as f: with open(file_name, "a+") as f:
time_stamp_current = self._get_ms_timestamp() time_stamp_current = self._get_ms_timestamp()
f.write("time: {}, epoch: {}, step: {}, outputs are {}.\n".format( f.write("time: {}, epoch: {}, step: {}, outputs are {},{},{}.\n".format(
time_stamp_current - self.time_stamp_first, time_stamp_current - self.time_stamp_first,
cb_params.cur_epoch_num, cb_params.cur_epoch_num,
cb_params.cur_step_num, cb_params.cur_step_num,
str(cb_params.net_outputs) str(cb_params.net_outputs[0].asnumpy()),
str(cb_params.net_outputs[1].asnumpy()),
str(cb_params.net_outputs[2].asnumpy())
)) ))
@staticmethod @staticmethod
......
...@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor ...@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor
from mindspore.nn import Momentum from mindspore.nn import Momentum
from mindspore.nn.optim import Adam, Lamb from mindspore.nn.optim import Adam, Lamb
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore import context, ParallelMode, Parameter from mindspore import context, ParallelMode, Parameter
from mindspore.communication import management as MultiAscend from mindspore.communication import management as MultiAscend
...@@ -41,18 +41,7 @@ from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate ...@@ -41,18 +41,7 @@ from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate
parser = argparse.ArgumentParser(description='MASS train entry point.') parser = argparse.ArgumentParser(description='MASS train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.") parser.add_argument("--config", type=str, required=True, help="model config json file path.")
parser.add_argument("--platform", type=str, required=True, help="model working platform.")
device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
raise RuntimeError("`DEVICE_ID` can not be None.")
device_id = int(device_id)
context.set_context(
mode=context.GRAPH_MODE,
device_target="Ascend",
reserve_class_name_in_scope=False,
device_id=device_id)
def get_config(config): def get_config(config):
config = TransformerConfig.from_json_file(config) config = TransformerConfig.from_json_file(config)
...@@ -79,12 +68,11 @@ def _train(model, config: TransformerConfig, ...@@ -79,12 +68,11 @@ def _train(model, config: TransformerConfig,
if pre_training_dataset is not None: if pre_training_dataset is not None:
print(" | Start pre-training job.") print(" | Start pre-training job.")
epoch_size = config.epochs * pre_training_dataset.get_dataset_size() // config.dataset_sink_step
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1: if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
print(f" | Rank {MultiAscend.get_rank()} Call model train.") print(f" | Rank {MultiAscend.get_rank()} Call model train.")
model.train(epoch_size, pre_training_dataset, model.train(config.epochs, pre_training_dataset,
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
sink_size=config.dataset_sink_step) sink_size=config.dataset_sink_step)
...@@ -97,9 +85,8 @@ def _train(model, config: TransformerConfig, ...@@ -97,9 +85,8 @@ def _train(model, config: TransformerConfig,
if fine_tune_dataset is not None: if fine_tune_dataset is not None:
print(" | Start fine-tuning job.") print(" | Start fine-tuning job.")
epoch_size = config.epochs * fine_tune_dataset.get_dataset_size() // config.dataset_sink_step
model.train(epoch_size, fine_tune_dataset, model.train(config.epochs, fine_tune_dataset,
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
sink_size=config.dataset_sink_step) sink_size=config.dataset_sink_step)
...@@ -114,7 +101,8 @@ def _train(model, config: TransformerConfig, ...@@ -114,7 +101,8 @@ def _train(model, config: TransformerConfig,
def _build_training_pipeline(config: TransformerConfig, def _build_training_pipeline(config: TransformerConfig,
pre_training_dataset=None, pre_training_dataset=None,
fine_tune_dataset=None, fine_tune_dataset=None,
test_dataset=None): test_dataset=None,
platform="Ascend"):
""" """
Build training pipeline. Build training pipeline.
...@@ -198,14 +186,15 @@ def _build_training_pipeline(config: TransformerConfig, ...@@ -198,14 +186,15 @@ def _build_training_pipeline(config: TransformerConfig,
else: else:
raise ValueError(f"optimizer only support `adam` and `momentum` now.") raise ValueError(f"optimizer only support `adam` and `momentum` now.")
# Dynamic loss scale. # loss scale.
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale, if platform == "Ascend":
scale_factor=config.loss_scale_factor, scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
scale_window=config.scale_window) scale_factor=config.loss_scale_factor,
net_with_grads = TransformerTrainOneStepWithLossScaleCell( scale_window=config.scale_window)
network=net_with_loss, optimizer=optimizer, else:
scale_update_cell=scale_manager.get_update_cell() scale_manager = FixedLossScaleManager(loss_scale=1.0, drop_overflow_update=True)
) net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss, optimizer=optimizer,
scale_update_cell=scale_manager.get_update_cell())
net_with_grads.set_train(True) net_with_grads.set_train(True)
model = Model(net_with_grads) model = Model(net_with_grads)
loss_monitor = LossCallBack(config) loss_monitor = LossCallBack(config)
...@@ -236,9 +225,12 @@ def _build_training_pipeline(config: TransformerConfig, ...@@ -236,9 +225,12 @@ def _build_training_pipeline(config: TransformerConfig,
callbacks=callbacks) callbacks=callbacks)
def _setup_parallel_env(): def _setup_parallel_env(platform):
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
MultiAscend.init() if platform == "GPU":
MultiAscend.init("nccl")
else:
MultiAscend.init()
context.set_auto_parallel_context( context.set_auto_parallel_context(
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=MultiAscend.get_group_size(), device_num=MultiAscend.get_group_size(),
...@@ -247,14 +239,14 @@ def _setup_parallel_env(): ...@@ -247,14 +239,14 @@ def _setup_parallel_env():
) )
def train_parallel(config: TransformerConfig): def train_parallel(config: TransformerConfig, platform: "Ascend"):
""" """
Train model with multi ascend chips. Train model with multi ascend chips.
Args: Args:
config (TransformerConfig): Config for MASS model. config (TransformerConfig): Config for MASS model.
""" """
_setup_parallel_env() _setup_parallel_env(platform)
print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.") print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
...@@ -286,10 +278,11 @@ def train_parallel(config: TransformerConfig): ...@@ -286,10 +278,11 @@ def train_parallel(config: TransformerConfig):
_build_training_pipeline(config=config, _build_training_pipeline(config=config,
pre_training_dataset=pre_train_dataset, pre_training_dataset=pre_train_dataset,
fine_tune_dataset=fine_tune_dataset, fine_tune_dataset=fine_tune_dataset,
test_dataset=test_dataset) test_dataset=test_dataset,
platform=platform)
def train_single(config: TransformerConfig): def train_single(config: TransformerConfig, platform: "Ascend"):
""" """
Train model on single device. Train model on single device.
...@@ -316,7 +309,8 @@ def train_single(config: TransformerConfig): ...@@ -316,7 +309,8 @@ def train_single(config: TransformerConfig):
_build_training_pipeline(config=config, _build_training_pipeline(config=config,
pre_training_dataset=pre_train_dataset, pre_training_dataset=pre_train_dataset,
fine_tune_dataset=fine_tune_dataset, fine_tune_dataset=fine_tune_dataset,
test_dataset=test_dataset) test_dataset=test_dataset,
platform=platform)
def _check_args(config): def _check_args(config):
...@@ -327,9 +321,20 @@ def _check_args(config): ...@@ -327,9 +321,20 @@ def _check_args(config):
if __name__ == '__main__': if __name__ == '__main__':
args, _ = parser.parse_known_args()
device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
device_id = 0
device_id = int(device_id)
context.set_context(
mode=context.GRAPH_MODE,
device_target=args.platform,
reserve_class_name_in_scope=False,
device_id=device_id)
_rank_size = os.getenv('RANK_SIZE') _rank_size = os.getenv('RANK_SIZE')
args, _ = parser.parse_known_args()
_check_args(args.config) _check_args(args.config)
_config = get_config(args.config) _config = get_config(args.config)
...@@ -337,6 +342,6 @@ if __name__ == '__main__': ...@@ -337,6 +342,6 @@ if __name__ == '__main__':
context.set_context(save_graphs=_config.save_graphs) context.set_context(save_graphs=_config.save_graphs)
if _rank_size is not None and int(_rank_size) > 1: if _rank_size is not None and int(_rank_size) > 1:
train_parallel(_config) train_parallel(_config, args.platform)
else: else:
train_single(_config) train_single(_config, args.platform)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册