From 16544b37d688b2dd3aacba396c6fa293a5d66920 Mon Sep 17 00:00:00 2001 From: wukesong Date: Tue, 23 Jun 2020 22:14:38 +0800 Subject: [PATCH] modify --- model_zoo/alexnet/src/dataset.py | 2 +- model_zoo/lenet/src/lenet.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/model_zoo/alexnet/src/dataset.py b/model_zoo/alexnet/src/dataset.py index fe1822579..6e9f310be 100644 --- a/model_zoo/alexnet/src/dataset.py +++ b/model_zoo/alexnet/src/dataset.py @@ -16,11 +16,11 @@ Produce the dataset """ -from config import alexnet_cfg as cfg import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as CV from mindspore.common import dtype as mstype +from .config import alexnet_cfg as cfg def create_dataset_mnist(data_path, batch_size=32, repeat_size=1, status="train"): diff --git a/model_zoo/lenet/src/lenet.py b/model_zoo/lenet/src/lenet.py index 3864315db..a570c6cd9 100644 --- a/model_zoo/lenet/src/lenet.py +++ b/model_zoo/lenet/src/lenet.py @@ -43,11 +43,12 @@ class LeNet5(nn.Cell): Args: num_class (int): Num classes. Default: 10. + channel (int): Num channels. Default: 1. Returns: Tensor, output tensor Examples: - >>> LeNet(num_class=10) + >>> LeNet(num_class=10, channel=1) """ def __init__(self, num_class=10, channel=1): -- GitLab