提交 4a01d236 编写于 作者: M Megvii Engine Team

ci(mgb/python_module): fix doctest syntax error

GitOrigin-RevId: e98ca810ea7ca46de473ac14ae9078cfc9a4bd32
上级 b3ac87e4
......@@ -82,7 +82,7 @@ class Function(metaclass=ABCMeta):
self.save_for_backward(y)
return y
def backward(self. output_grads):
def backward(self, output_grads):
(y, ) = self.saved_tensors
return output_grads * y * (1-y)
......
......@@ -21,26 +21,27 @@ class Sequential(Module):
.. testcode::
from collections import OrderedDict
import numpy as np
import megengine.nn as nn
import megengine.nn.functional as F
import megengine.functional as F
from megengine.module import Sequential, Linear
from megengine import tensor
batch_size = 64
data = nn.Input("data", shape=(batch_size, 1, 28, 28), dtype=np.float32, value=np.zeros((batch_size, 1, 28, 28)))
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,))
data = tensor(np.zeros((batch_size, 1, 28, 28)), dtype=np.float32)
label = tensor(np.zeros(batch_size,), dtype=np.int32)
data = data.reshape(batch_size, -1)
net0 = nn.Sequential(
nn.Linear(28 * 28, 320),
nn.Linear(320, 10)
net0 = Sequential(
Linear(28 * 28, 320),
Linear(320, 10)
)
pred0 = net0(data)
modules = OrderedDict()
modules["fc0"] = nn.Linear(28 * 28, 320)
modules["fc1"] = nn.Linear(320, 10)
net1 = nn.Sequential(modules)
modules["fc0"] = Linear(28 * 28, 320)
modules["fc1"] = Linear(320, 10)
net1 = Sequential(modules)
pred1 = net1(data)
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册