提交 8a484dbd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!812 [CR] add lenet train and eval st case

Merge pull request !812 from jinyaohui/train_eval
......@@ -13,18 +13,26 @@
# limitations under the License.
# ============================================================================
import os
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn import Dense
from mindspore.common.initializer import initializer
import mindspore.nn as nn
from mindspore.nn import Dense, TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train.callback import LossMonitor
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
......@@ -64,7 +72,7 @@ class LeNet(nn.Cell):
def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32):
lr = []
for step in range(total_steps):
lr_ = base_lr * gamma ** (step//gap)
lr_ = base_lr * gamma ** (step // gap)
lr.append(lr_)
return Tensor(np.array(lr), dtype)
......@@ -90,3 +98,60 @@ def test_train_lenet():
loss = train_network(data, label)
losses.append(loss)
print(losses)
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_and_eval_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_mem_reuse=False)
network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============")
ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1)
model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True)
print("============== Starting Testing ==============")
ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1)
acc = model.eval(ds_eval, dataset_sink_mode=True)
print("============== Accuracy:{} ==============".format(acc))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册