提交 669a8969 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3421 Add WarpCTC GPU script

Merge pull request !3421 from yangyongjie/master
......@@ -31,7 +31,8 @@ These is an example of training Warpctc with self-generated captcha image datase
└──warpct
├── README.md
├── script
├── run_distribute_train.sh # launch distributed training(8 pcs)
├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs)
├── run_distribute_train_for_gpu.sh # launch distributed training in GPU
├── run_eval.sh # launch evaluation
├── run_process_data.sh # launch dataset generation
└── run_standalone_train.sh # launch standalone training(1 pcs)
......@@ -75,22 +76,31 @@ Parameters for both training and evaluation can be set in config.py.
#### Usage
```
# distributed training
# distributed training in Ascend
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
# distributed training in GPU
Usage: sh run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH]
Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]
```
#### Launch
```
# distribute training example
# distribute training example in Ascend
sh run_distribute_train.sh rank_table.json ../data/train
# standalone training example
sh run_standalone_train.sh ../data/train
# distribute training example in GPU
sh run_distribute_train.sh 8 ../data/train
# standalone training example in Ascend
sh run_standalone_train.sh ../data/train Ascend
# standalone training example in GPU
sh run_standalone_train.sh ../data/train GPU
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
......@@ -116,14 +126,17 @@ Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809]
```
# evaluation
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
```
#### Launch
```
# evaluation example
sh run_eval.sh ../data/test warpctc-30-98.ckpt
# evaluation example in Ascend
sh run_eval.sh ../data/test warpctc-30-98.ckpt Ascend
# evaluation example in GPU
sh run_eval.sh ../data/test warpctc-30-98.ckpt GPU
```
> checkpoint can be produced in training process.
......
......@@ -23,10 +23,10 @@ from mindspore import dataset as de
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss
from src.loss import CTCLoss, CTCLossV2
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN
from src.warpctc import StackedRNN, StackedRNNForGPU
from src.metric import WarpCTCAccuracy
random.seed(1)
......@@ -36,30 +36,38 @@ de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
dataset = create_dataset(dataset_path=args_opt.dataset_path,
batch_size=cf.batch_size,
device_target=args_opt.platform)
step_size = dataset.get_dataset_size()
# define loss
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
# define net
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define model
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()})
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)})
# start evaluation
res = model.eval(dataset)
res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
print("result:", res, flush=True)
......@@ -57,6 +57,6 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log &
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 &
cd ..
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [DATASET_PATH]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
RANK_SIZE=$1
DATASET_PATH=$(get_real_path $2)
if [ ! -d $DATASET_PATH ]; then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ -d "distribute_train" ]; then
rm -rf ./distribute_train
fi
mkdir ./distribute_train
cp ../*.py ./distribute_train
cp -r ../src ./distribute_train
cd ./distribute_train || exit
mpirun --allow-run-as-root -n $RANK_SIZE \
python train.py \
--dataset_path=$DATASET_PATH \
--platform=GPU \
--run_distribute > log.txt 2>&1 &
cd ..
......@@ -14,8 +14,8 @@
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
if [ $# != 3 ]; then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
exit 1
fi
......@@ -29,6 +29,7 @@ get_real_path() {
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PLATFORM=$3
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
......@@ -40,21 +41,44 @@ if [ ! -f $PATH2 ]; then
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
run_ascend() {
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ]; then
rm -rf ./eval
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
cd ..
}
run_gpu() {
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 &
cd ..
}
if [ "Ascend" == $PLATFORM ]; then
run_ascend $PATH1 $PATH2
elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1 $PATH2
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log &
cd ..
......@@ -14,8 +14,8 @@
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
if [ $# != 2 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]"
exit 1
fi
......@@ -28,27 +28,44 @@ get_real_path() {
}
PATH1=$(get_real_path $1)
PLATFORM=$2
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
run_ascend() {
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
cd ..
}
run_gpu() {
env >env.log
python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
cd ..
}
if [ -d "train" ]; then
rm -rf ./train
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset=$PATH1 &>log &
cd ..
if [ "Ascend" == $PLATFORM ]; then
run_ascend $PATH1
elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
fi
\ No newline at end of file
......@@ -24,24 +24,25 @@ from PIL import Image
from src.config import config as cf
class _CaptchaDataset():
class _CaptchaDataset:
"""
create train or evaluation dataset for warpctc
Args:
img_root_dir(str): root path of images
max_captcha_digits(int): max number of digits in images.
blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label
length is less than max_captcha_digits, the remaining labels are padding with blank.
device_target(str): platform of training, support Ascend and GPU.
"""
def __init__(self, img_root_dir, max_captcha_digits, blank=10):
def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir
self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
self.max_captcha_digits = max_captcha_digits
self.blank = blank
self.target = device_target
self.blank = 10 if self.target == 'Ascend' else 0
self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names]
def __len__(self):
return len(self.img_names)
......@@ -54,27 +55,33 @@ class _CaptchaDataset():
image = np.array(im)
label_str = os.path.splitext(img_name)[0]
label_str = label_str[label_str.find('-') + 1:]
label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
if self.target == 'Ascend':
label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
else:
label = [int(i) + 1 for i in label_str]
length = len(label)
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
label.append(length)
label = np.array(label)
return image, label
def create_dataset(dataset_path, repeat_num=1, batch_size=1):
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
"""
create train or evaluation dataset for warpctc
Args:
dataset_path(int): dataset path
repeat_num(int): dataset repetition num, default is 1
batch_size(int): batch size of generated dataset, default is 1
num_shards(int): number of devices
shard_id(int): rank id
device_target(str): platform of training, support Ascend and GPU
"""
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits)
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id)
ds.set_dataset_size(m.ceil(len(dataset) / rank_size))
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
ds.set_dataset_size(m.ceil(len(dataset) / num_shards))
image_trans = [
vc.Rescale(1.0 / 255.0, 0.0),
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
......@@ -87,6 +94,5 @@ def create_dataset(dataset_path, repeat_num=1, batch_size=1):
ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans)
ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans)
ds = ds.batch(batch_size)
ds = ds.repeat(repeat_num)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
......@@ -47,3 +47,25 @@ class CTCLoss(_Loss):
labels_values = self.reshape(label, (-1,))
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
return loss
class CTCLossV2(_Loss):
"""
CTCLoss definition
Args:
max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image width
batch_size(int): batch size of input logits
"""
def __init__(self, max_sequence_length, batch_size):
super(CTCLossV2, self).__init__()
self.input_length = Tensor(np.array([max_sequence_length] * batch_size), mstype.int32)
self.reshape = P.Reshape()
self.ctc_loss = P.CTCLossV2()
def construct(self, logit, label):
labels_values = label[:, :-1]
labels_length = label[:, -1]
loss, _ = self.ctc_loss(logit, labels_values, self.input_length, labels_length)
return loss
......@@ -15,19 +15,19 @@
"""Metric for accuracy evaluation."""
from mindspore import nn
BLANK_LABLE = 10
class WarpCTCAccuracy(nn.Metric):
"""
Define accuracy metric for warpctc network.
"""
def __init__(self):
def __init__(self, device_target='Ascend'):
super(WarpCTCAccuracy).__init__()
self._correct_num = 0
self._total_num = 0
self._count = 0
self.device_target = device_target
self.blank = 10 if device_target == 'Ascend' else 0
def clear(self):
self._correct_num = 0
......@@ -39,6 +39,8 @@ class WarpCTCAccuracy(nn.Metric):
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
if self.device_target == 'GPU':
y = y[:, :-1]
self._count += 1
......@@ -54,8 +56,7 @@ class WarpCTCAccuracy(nn.Metric):
raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num
@staticmethod
def _is_eq(pred_lbl, target):
def _is_eq(self, pred_lbl, target):
"""
check whether predict label is equal to target label
"""
......@@ -63,11 +64,10 @@ class WarpCTCAccuracy(nn.Metric):
pred_diff = len(target) - len(pred_lbl)
if pred_diff > 0:
# padding by BLANK_LABLE
pred_lbl.extend([BLANK_LABLE] * pred_diff)
pred_lbl.extend([self.blank] * pred_diff)
return pred_lbl == target
@staticmethod
def _get_prediction(y_pred):
def _get_prediction(self, y_pred):
"""
parse predict result to labels
"""
......@@ -78,11 +78,11 @@ class WarpCTCAccuracy(nn.Metric):
pred_lbls = []
for i in range(batch_size):
idx = indices[:, i]
last_idx = BLANK_LABLE
last_idx = self.blank
pred_lbl = []
for j in range(lens[i]):
cur_idx = idx[j]
if cur_idx not in [last_idx, BLANK_LABLE]:
if cur_idx not in [last_idx, self.blank]:
pred_lbl.append(cur_idx)
last_idx = cur_idx
pred_lbls.append(pred_lbl)
......
......@@ -88,3 +88,52 @@ class StackedRNN(nn.Cell):
output = self.concat((output, h2_after_fc))
return output
class StackedRNNForGPU(nn.Cell):
"""
Define a stacked RNN network which contains two LSTM layers and one full-connect layer.
Args:
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
captcha images.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
num_layer(int): the number of layer of LSTM.
"""
def __init__(self, input_size, batch_size=64, hidden_size=512, num_layer=2):
super(StackedRNNForGPU, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
self.num_classes = 11
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / hidden_size) ** 0.5
weight_shape = 4 * hidden_size * (input_size + 3 * hidden_size + 4)
self.weight = Parameter(np.random.uniform(-k, k, (weight_shape, 1, 1)).astype(np.float32), name='weight')
self.h = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))
self.c = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))
self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
self.lstm.weight = self.weight
self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
self.fc_bias = np.random.random(self.num_classes).astype(np.float32)
self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight),
bias_init=Tensor(self.fc_bias))
self.fc.to_float(mstype.float32)
self.expand_dims = P.ExpandDims()
self.concat = P.Concat()
self.transpose = P.Transpose()
def construct(self, x):
x = self.transpose(x, (3, 0, 2, 1))
x = self.reshape(x, (-1, self.batch_size, self.input_size))
output, _ = self.lstm(x, (self.h, self.c))
res = ()
for i in range(F.shape(x)[0]):
res += (self.expand_dims(self.fc(output[i]), 0),)
res = self.concat(res)
return res
......@@ -42,7 +42,7 @@ grad_div = C.MultitypeFuncGraph("grad_div")
@grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad):
div = P.Div()
div = P.RealDiv()
mul = P.Mul()
grad = mul(grad, 10.0)
ret = div(grad, val)
......
......@@ -24,12 +24,12 @@ from mindspore import dataset as de
from mindspore.train.model import Model, ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init
from mindspore.communication.management import init, get_group_size, get_rank
from src.loss import CTCLoss
from src.loss import CTCLoss, CTCLossV2
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN
from src.warpctc import StackedRNN, StackedRNNForGPU
from src.warpctc_for_train import TrainOneStepCellWithGradClip
from src.lr_schedule import get_lr
......@@ -38,38 +38,60 @@ np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.")
parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.')
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
lr_scale = 1
if args_opt.run_distribute:
if args_opt.platform == 'Ascend':
init()
lr_scale = 1
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
else:
init('nccl')
lr_scale = 0.5
device_num = get_group_size()
rank = get_rank()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=args_opt.device_num,
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
init()
else:
device_num = 1
rank = 0
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size,
num_shards=device_num, shard_id=rank, device_target=args_opt.platform)
step_size = dataset.get_dataset_size()
# define lr
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
lr = get_lr(cf.epoch_size, step_size, lr_init)
# define loss
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
# define net
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# define opt
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt).set_train()
# define model
......@@ -79,6 +101,6 @@ if __name__ == '__main__':
if cf.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
keep_checkpoint_max=cf.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck)
ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path, config=config_ck)
callbacks.append(ckpt_cb)
model.train(cf.epoch_size, dataset, callbacks=callbacks)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册