提交 fe1b7358 编写于 作者: K kswang 提交者: 高东海

add cpu st lenet

上级 3b4d8fce
......@@ -12,25 +12,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
test network
Usage:
python test_network_main.py --net lenet --target Davinci
"""
import os
import time
import pytest
import numpy as np
import argparse
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
import mindspore.context as context
from mindspore.nn.optim import Momentum
from models.lenet import LeNet
from models.resnetv1_5 import resnet50
from models.alexnet import AlexNet
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def train(net, data, label):
......@@ -48,15 +67,6 @@ def train(net, data, label):
print("+++++++++++++++++++++++++++")
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resnet50():
data = Tensor(np.ones([32, 3 ,224, 224]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = resnet50(32, 10)
train(net, data, label)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
......@@ -65,12 +75,3 @@ def test_lenet():
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
train(net, data, label)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_alexnet():
data = Tensor(np.ones([32, 3 ,227, 227]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = AlexNet()
train(net, data, label)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册