提交 4a01d236 编写于 作者: M Megvii Engine Team

ci(mgb/python_module): fix doctest syntax error

GitOrigin-RevId: e98ca810ea7ca46de473ac14ae9078cfc9a4bd32
上级 b3ac87e4
...@@ -82,7 +82,7 @@ class Function(metaclass=ABCMeta): ...@@ -82,7 +82,7 @@ class Function(metaclass=ABCMeta):
self.save_for_backward(y) self.save_for_backward(y)
return y return y
def backward(self. output_grads): def backward(self, output_grads):
(y, ) = self.saved_tensors (y, ) = self.saved_tensors
return output_grads * y * (1-y) return output_grads * y * (1-y)
......
...@@ -21,26 +21,27 @@ class Sequential(Module): ...@@ -21,26 +21,27 @@ class Sequential(Module):
.. testcode:: .. testcode::
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
import megengine.nn as nn import megengine.functional as F
import megengine.nn.functional as F from megengine.module import Sequential, Linear
from megengine import tensor
batch_size = 64 batch_size = 64
data = nn.Input("data", shape=(batch_size, 1, 28, 28), dtype=np.float32, value=np.zeros((batch_size, 1, 28, 28))) data = tensor(np.zeros((batch_size, 1, 28, 28)), dtype=np.float32)
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,)) label = tensor(np.zeros(batch_size,), dtype=np.int32)
data = data.reshape(batch_size, -1) data = data.reshape(batch_size, -1)
net0 = nn.Sequential( net0 = Sequential(
nn.Linear(28 * 28, 320), Linear(28 * 28, 320),
nn.Linear(320, 10) Linear(320, 10)
) )
pred0 = net0(data) pred0 = net0(data)
modules = OrderedDict() modules = OrderedDict()
modules["fc0"] = nn.Linear(28 * 28, 320) modules["fc0"] = Linear(28 * 28, 320)
modules["fc1"] = nn.Linear(320, 10) modules["fc1"] = Linear(320, 10)
net1 = nn.Sequential(modules) net1 = Sequential(modules)
pred1 = net1(data) pred1 = net1(data)
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册