diff --git a/example/data_processing.py b/example/data_processing.py index 8b4007edba0f8b392a33cdbacca008d0593755e9..d1b93eb0e8f9d619489f98ed90a196fd48416895 100644 --- a/example/data_processing.py +++ b/example/data_processing.py @@ -37,10 +37,10 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, rescale_op = CV.Rescale(rescale, shift) hwc2chw_op = CV.HWC2CHW() type_cast_op = C.TypeCast(mstype.int32) - one_hot_enco = C.OneHot(10) # apply map operations on images if not sparse: + one_hot_enco = C.OneHot(10) ds1 = ds1.map(input_columns="label", operations=one_hot_enco, num_parallel_workers=num_parallel_workers) type_cast_op = C.TypeCast(mstype.float32) diff --git a/example/mnist_demo/lenet5_net.py b/example/mnist_demo/lenet5_net.py index 0606015b170605d56800ec4d6d069ad82f3386e7..7f5ead321cce8441bf204b5a1102f1d1712bcdbc 100644 --- a/example/mnist_demo/lenet5_net.py +++ b/example/mnist_demo/lenet5_net.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import mindspore.nn as nn -import mindspore.ops.operations as P +from mindspore import nn from mindspore.common.initializer import TruncatedNormal @@ -30,7 +29,7 @@ def fc_with_initialize(input_channels, out_channels): def weight_variable(): - return TruncatedNormal(0.2) + return TruncatedNormal(0.02) class LeNet5(nn.Cell): @@ -46,7 +45,7 @@ class LeNet5(nn.Cell): self.fc3 = fc_with_initialize(84, 10) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() + self.flatten = nn.Flatten() def construct(self, x): x = self.conv1(x) @@ -55,7 +54,7 @@ class LeNet5(nn.Cell): x = self.conv2(x) x = self.relu(x) x = self.max_pool2d(x) - x = self.reshape(x, (-1, 16*5*5)) + x = self.flatten(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) diff --git a/example/mnist_demo/mnist_train.py b/example/mnist_demo/mnist_train.py index d9ef839a34c5e1d235d011760b0d6e0b919f6654..eeaba3f80a3b5f945d8fdbb8f21e088509c59844 100644 --- a/example/mnist_demo/mnist_train.py +++ b/example/mnist_demo/mnist_train.py @@ -20,10 +20,7 @@ from mindspore import context, Tensor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train import Model -import mindspore.ops.operations as P from mindspore.nn.metrics import Accuracy -from mindspore.ops import functional as F -from mindspore.common import dtype as mstype from mindarmour.utils.logger import LogUtil @@ -32,26 +29,7 @@ from lenet5_net import LeNet5 sys.path.append("..") from data_processing import generate_mnist_dataset LOGGER = LogUtil.get_instance() -TAG = 'Lenet5_train' - - -class CrossEntropyLoss(nn.Cell): - """ - Define loss for network - """ - def __init__(self): - super(CrossEntropyLoss, self).__init__() - self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - - def construct(self, logits, label): - label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value) - loss = self.cross_entropy(logits, label)[0] - loss = self.mean(loss, (-1,)) - return loss +TAG = "Lenet5_train" def mnist_train(epoch_size, batch_size, lr, momentum): @@ -66,23 +44,29 @@ def mnist_train(epoch_size, batch_size, lr, momentum): batch_size=batch_size, repeat_size=1) network = LeNet5() - network.set_train() - net_loss = CrossEntropyLoss() + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, + reduction="mean") net_opt = nn.Momentum(network.trainable_params(), lr, momentum) - config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory='./trained_ckpt_file/', config=config_ck) + config_ck = CheckpointConfig(save_checkpoint_steps=1875, + keep_checkpoint_max=10) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory="./trained_ckpt_file/", + config=config_ck) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) LOGGER.info(TAG, "============== Starting Training ==============") - model.train(epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train + model.train(epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()], + dataset_sink_mode=False) LOGGER.info(TAG, "============== Starting Testing ==============") - param_dict = load_checkpoint("trained_ckpt_file/checkpoint_lenet-10_1875.ckpt") + ckpt_file_name = "trained_ckpt_file/checkpoint_lenet-10_1875.ckpt" + param_dict = load_checkpoint(ckpt_file_name) load_param_into_net(network, param_dict) - ds_eval = generate_mnist_dataset(os.path.join(mnist_path, "test"), batch_size=batch_size) - acc = model.eval(ds_eval) + ds_eval = generate_mnist_dataset(os.path.join(mnist_path, "test"), + batch_size=batch_size) + acc = model.eval(ds_eval, dataset_sink_mode=False) LOGGER.info(TAG, "============== Accuracy: %s ==============", acc) if __name__ == '__main__': - mnist_train(10, 32, 0.001, 0.9) + mnist_train(10, 32, 0.01, 0.9)