提交 66df8437 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!9 Update mnist_lenet5 example

Merge pull request !9 from pkuliuliu/master
......@@ -37,10 +37,10 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
one_hot_enco = C.OneHot(10)
# apply map operations on images
if not sparse:
one_hot_enco = C.OneHot(10)
ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
num_parallel_workers=num_parallel_workers)
type_cast_op = C.TypeCast(mstype.float32)
......
......@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
......@@ -30,7 +29,7 @@ def fc_with_initialize(input_channels, out_channels):
def weight_variable():
return TruncatedNormal(0.2)
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
......@@ -46,7 +45,7 @@ class LeNet5(nn.Cell):
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
......@@ -55,7 +54,7 @@ class LeNet5(nn.Cell):
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (-1, 16*5*5))
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
......
......@@ -20,10 +20,7 @@ from mindspore import context, Tensor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
import mindspore.ops.operations as P
from mindspore.nn.metrics import Accuracy
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindarmour.utils.logger import LogUtil
......@@ -32,26 +29,7 @@ from lenet5_net import LeNet5
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'Lenet5_train'
class CrossEntropyLoss(nn.Cell):
"""
Define loss for network
"""
def __init__(self):
super(CrossEntropyLoss, self).__init__()
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
def construct(self, logits, label):
label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
loss = self.cross_entropy(logits, label)[0]
loss = self.mean(loss, (-1,))
return loss
TAG = "Lenet5_train"
def mnist_train(epoch_size, batch_size, lr, momentum):
......@@ -66,23 +44,29 @@ def mnist_train(epoch_size, batch_size, lr, momentum):
batch_size=batch_size, repeat_size=1)
network = LeNet5()
network.set_train()
net_loss = CrossEntropyLoss()
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory='./trained_ckpt_file/', config=config_ck)
config_ck = CheckpointConfig(save_checkpoint_steps=1875,
keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory="./trained_ckpt_file/",
config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
LOGGER.info(TAG, "============== Starting Training ==============")
model.train(epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train
model.train(epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()],
dataset_sink_mode=False)
LOGGER.info(TAG, "============== Starting Testing ==============")
param_dict = load_checkpoint("trained_ckpt_file/checkpoint_lenet-10_1875.ckpt")
ckpt_file_name = "trained_ckpt_file/checkpoint_lenet-10_1875.ckpt"
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(network, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(mnist_path, "test"), batch_size=batch_size)
acc = model.eval(ds_eval)
ds_eval = generate_mnist_dataset(os.path.join(mnist_path, "test"),
batch_size=batch_size)
acc = model.eval(ds_eval, dataset_sink_mode=False)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
if __name__ == '__main__':
mnist_train(10, 32, 0.001, 0.9)
mnist_train(10, 32, 0.01, 0.9)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册