提交 6cab1dd7 编写于 作者: M Megvii Engine Team

refactor(mge/sdk): update xor-deploy

GitOrigin-RevId: 372c37cdc5116834b47344c16c2a443c6e7ebfdb
上级 53ec6b83
import numpy as np
import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
......@@ -35,57 +36,54 @@ class XORNet(M.Module):
return x
@trace(symbolic=True)
def train_fun(data, label, net=None, opt=None):
net.train()
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
opt.backward(loss)
return pred, loss
@trace(symbolic=True)
def val_fun(data, label, net=None):
net.eval()
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
return pred, loss
@trace(symbolic=True)
def pred_fun(data, net=None):
net.eval()
pred = net(data)
pred_normalized = F.softmax(pred)
return pred_normalized
def main():
if not mge.is_cuda_available():
mge.set_default_device("cpux")
net = XORNet()
gm = ad.GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
batch_size = 64
train_dataset = minibatch_generator(batch_size)
val_dataset = minibatch_generator(batch_size)
data = mge.tensor()
label = mge.tensor(np.zeros((batch_size,)), dtype=np.int32)
def train_fun(data, label):
opt.clear_grad()
with gm:
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
gm.backward(loss)
opt.step()
return pred, loss
def val_fun(data, label):
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
return pred, loss
@trace(symbolic=True, capture_as_const=True)
def pred_fun(data):
pred = net(data)
pred_normalized = F.softmax(pred)
return pred_normalized
data = np.random.random((batch_size, 2)).astype(np.float32)
label = np.zeros((batch_size,)).astype(np.int32)
train_loss = []
val_loss = []
for step, minibatch in enumerate(train_dataset):
if step > 1000:
break
data.set_value(minibatch["data"])
label.set_value(minibatch["label"])
opt.zero_grad()
_, loss = train_fun(data, label, net=net, opt=opt)
data = minibatch["data"]
label = minibatch["label"]
net.train()
_, loss = train_fun(data, label)
train_loss.append((step, loss.numpy()))
if step % 50 == 0:
minibatch = next(val_dataset)
_, loss = val_fun(data, label, net=net)
net.eval()
_, loss = val_fun(data, label)
loss = loss.numpy()[0]
val_loss.append((step, loss))
print("Step: {} loss={}".format(step, loss))
......@@ -108,8 +106,10 @@ def main():
]
)
data.set_value(test_data)
out = pred_fun(data, net=net)
# tracing only accepts tensor as input
data = mge.tensor(test_data, dtype=np.float32)
net.eval()
out = pred_fun(data)
pred_output = out.numpy()
pred_label = np.argmax(pred_output, 1)
......@@ -125,11 +125,8 @@ def main():
model_name = "xornet_deploy.mge"
if pred_fun.enabled:
print("Dump model as {}".format(model_name))
pred_fun.dump(model_name, arg_names=["data"])
else:
print("pred_fun must be run with trace enabled in order to dump model")
print("Dump model as {}".format(model_name))
pred_fun.dump(model_name, arg_names=["data"])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册