提交 fc3eca84 编写于 作者: M Megvii Engine Team

feat(mge/imperative/jit): add dump input shape and xornet example

GitOrigin-RevId: 5e2acd405224daf1d320873b9b049676cb4cf081
上级 1569cab0
......@@ -78,9 +78,9 @@ class Graph(_imperative_rt.ComputingGraph):
opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
return opnode.outputs[0]
def make_h2d(self, *, dtype, device):
def make_h2d(self, *, dtype, device, shape=None, name=None):
device = as_device(device).to_c()
return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))
def dump(*args):
......
......@@ -51,6 +51,7 @@ class TensorInfo:
"value_read",
"device",
"dtype",
"shape",
"bound_data",
# resources for execution
"varnode",
......@@ -107,8 +108,8 @@ class trace:
self._active_tensors = weakref.WeakSet()
self._tensor_remaps = None
self._inputs_to_restore = None
self._args_bindings = None
self._kwargs_bindings = None
self._arg_bindings = None
self._kwarg_bindings = None
self._output_bindings = None
self._output_names = None
......@@ -329,9 +330,7 @@ class trace:
links = ()
if self._capture_as_const:
for h in itertools.chain(
self._args_bindings, self._kwargs_bindings.values()
):
for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
info = self._tinfo[h]
opnode = info.data_setter = G.InputNode(
device=info.device, dtype=info.dtype, graph=graph
......@@ -434,15 +433,19 @@ class trace:
h2v = {}
graph = G.Graph()
for i, h in enumerate(self._args_bindings):
for i, h in enumerate(self._arg_bindings):
info = self._tinfo[h]
h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
if arg_names:
h2v[h].name = arg_names[i]
for k, h in self._kwargs_bindings.items():
h2v[h] = graph.make_h2d(
dtype=info.dtype,
device=info.device,
shape=info.shape,
name=arg_names[i] if arg_names else None,
)
for k, h in self._kwarg_bindings.items():
info = self._tinfo[h]
h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
h2v[h].name = k
h2v[h] = graph.make_h2d(
dtype=info.dtype, device=info.device, shape=info.shape, name=k
)
for op, ihandles, ohandles in self._seq:
ivars = []
......@@ -479,11 +482,12 @@ class trace:
info.external = False
info.device = x.device
info.dtype = x.dtype
info.shape = x.shape
TraceMixin._TraceMixin__inject(x, h)
self._inputs_to_restore.append(x)
return h
self._args_bindings = []
self._arg_bindings = []
for i, x in enumerate(args):
x = find_raw_tensor(x)
if x is None:
......@@ -491,20 +495,20 @@ class trace:
"positional arguments should all be tensor "
"but args[%d] cannot be recognized as one" % i
)
self._args_bindings.append(record_input(x))
self._arg_bindings.append(record_input(x))
self._kwargs_bindings = {}
self._kwarg_bindings = {}
for k, x in kwargs.items():
x = find_raw_tensor(x)
if x is not None:
self._kwargs_bindings[k] = record_input(x)
self._kwarg_bindings[k] = record_input(x)
else:
if len(args) != len(self._args_bindings):
if len(args) != len(self._arg_bindings):
raise TraceMismatchError("positional argument length mismatch")
self._tensor_remaps = {}
for i, (h, x) in enumerate(zip(self._args_bindings, args)):
for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
x = find_raw_tensor(x)
if x is None:
raise TypeError(
......@@ -524,9 +528,9 @@ class trace:
x = find_raw_tensor(x)
if x is not None:
kwargs_tensors[k] = x
if set(kwargs_tensors) != set(self._kwargs_bindings):
too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
if set(kwargs_tensors) != set(self._kwarg_bindings):
too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
if too_many:
raise TraceMismatchError(
"keyword arguments found to be tensor this time "
......@@ -537,7 +541,7 @@ class trace:
"keyword arguments found to be non-tensor this time "
"but were tensor previously: %s" % " ".join(too_few)
)
for k, h in self._kwargs_bindings.items():
for k, h in self._kwarg_bindings.items():
x = kwargs_tensors[k]
info = self._tinfo[h]
if x.dtype != info.dtype:
......
......@@ -237,7 +237,7 @@ void init_graph_rt(py::module m) {
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
});
m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
if (!cn.valid()) {
throw py::type_error("device must be valid");
}
......@@ -248,8 +248,8 @@ void init_graph_rt(py::module m) {
if (name) {
config.name(*name);
}
return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
}, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
}, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
const CompNode& comp_node,
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
import os
import tempfile
import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.jit import trace
@contextlib.contextmanager
def mkstemp():
    """Yield the path of a fresh temporary file and delete it on exit.

    The descriptor returned by ``tempfile.mkstemp`` is closed before the path
    is yielded, so the caller is free to reopen, overwrite or replace the
    file by name. The file is removed when the ``with`` block exits, whether
    it exits normally or through an exception.

    Yields:
        str: path of the temporary file.
    """
    fd, path = tempfile.mkstemp()
    try:
        os.close(fd)
        yield path
    finally:
        # The consumer may already have deleted or moved the file; a missing
        # file at cleanup time must not mask the outcome of the ``with`` body.
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)
def minibatch_generator(batch_size):
    """Endlessly yield random minibatches for the XOR toy problem.

    Every sample is a 2-d point whose coordinates are drawn uniformly from
    ``[-1, 1)`` using NumPy's global random state. A sample is labelled 1
    when the product of its two coordinates is negative, and 0 otherwise.

    Args:
        batch_size: number of samples in each yielded batch.

    Yields:
        dict: ``"data"`` maps to a float32 array of shape ``(batch_size, 2)``
        and ``"label"`` maps to an int32 array of shape ``(batch_size,)``.
    """
    while True:
        # One vectorized draw consumes the random stream in the same order as
        # drawing two values per row, so seeded runs stay reproducible.
        inp_data = np.random.rand(batch_size, 2) * 2 - 1
        # Labels are derived from the float64 samples, before the float32 cast.
        label = (np.prod(inp_data, axis=1) < 0).astype(np.int32)
        yield {"data": inp_data.astype(np.float32), "label": label}
class XORNet(M.Module):
    """Three-layer fully connected classifier for the XOR toy problem.

    The layers are ``fc0`` (``num_class`` -> ``mid_dim``), ``fc1``
    (``mid_dim`` -> ``mid_dim``) and ``fc2`` (``mid_dim`` -> ``num_class``),
    with ``tanh`` applied after the first two. ``forward`` returns the raw
    output of ``fc2``; no softmax is applied inside the module.
    """

    def __init__(self):
        # The plain-int hyper-parameters are assigned before the base-class
        # initializer runs; that ordering is kept deliberately.
        self.mid_dim = 14
        self.num_class = 2
        super().__init__()
        hidden, classes = self.mid_dim, self.num_class
        # The input width of fc0 is taken from num_class (both are 2 here).
        self.fc0 = M.Linear(classes, hidden, bias=True)
        self.fc1 = M.Linear(hidden, hidden, bias=True)
        self.fc2 = M.Linear(hidden, classes, bias=True)

    def forward(self, x):
        """Return the output of ``fc2`` for the input batch ``x``."""
        hidden = F.tanh(self.fc0(x))
        hidden = F.tanh(self.fc1(hidden))
        return self.fc2(hidden)
def test_xornet_trace_dump():
    """Train XORNet through traced functions, then dump the traced predictor.

    The test runs 101 SGD steps (steps 0..100) with a traced training
    function and evaluates a traced validation function every 50 steps on a
    batch drawn from a separate validation generator. It then runs a
    symbolic, ``capture_as_const`` traced inference function on twelve fixed
    points, prints the predicted probabilities, and dumps that function to a
    temporary file with named input ``"data"`` and output ``"label"``.

    The test passes if every stage runs without raising; it makes no
    assertion on the loss or on the predictions.
    """
    net = XORNet()
    opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)

    batch_size = 64
    train_dataset = minibatch_generator(batch_size)
    val_dataset = minibatch_generator(batch_size)

    @trace
    def train_fun(data, label):
        with opt.record():
            net.train()
            pred = net(data)
            loss = F.cross_entropy_with_softmax(pred, label)
            opt.backward(loss)
        return pred, loss

    @trace
    def val_fun(data, label):
        net.eval()
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        return pred, loss

    @trace(symbolic=True, capture_as_const=True)
    def pred_fun(data):
        net.eval()
        pred = net(data)
        pred_normalized = F.softmax(pred)
        return pred_normalized

    train_loss = []
    val_loss = []
    for step, minibatch in enumerate(train_dataset):
        if step > 100:
            break
        data = tensor(minibatch["data"])
        label = tensor(minibatch["label"])
        opt.zero_grad()
        _, loss = train_fun(data, label)
        train_loss.append((step, loss.numpy()))
        if step % 50 == 0:
            # Evaluate on the freshly drawn validation batch. Previously the
            # batch was fetched but the training tensors were passed instead.
            val_batch = next(val_dataset)
            _, val_batch_loss = val_fun(
                tensor(val_batch["data"]), tensor(val_batch["label"])
            )
            val_batch_loss = val_batch_loss.numpy()[0]
            val_loss.append((step, val_batch_loss))
            print("Step: {} loss={}".format(step, val_batch_loss))
        opt.step()

    # Three points per quadrant: same-sign pairs first, then mixed-sign pairs.
    test_data = np.array(
        [
            (0.5, 0.5),
            (0.3, 0.7),
            (0.1, 0.9),
            (-0.5, -0.5),
            (-0.3, -0.7),
            (-0.9, -0.1),
            (0.5, -0.5),
            (0.3, -0.7),
            (0.9, -0.1),
            (-0.5, 0.5),
            (-0.3, 0.7),
            (-0.1, 0.9),
        ]
    )

    data = tensor(test_data.astype(np.float32))
    out = pred_fun(data)
    pred_output = out.numpy()

    with np.printoptions(precision=4, suppress=True):
        print("Predicated probability:")
        print(pred_output)

    # A distinct name keeps the prediction tensor `out` from being shadowed.
    with mkstemp() as dump_path:
        pred_fun.dump(dump_path, arg_names=["data"], output_names=["label"])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册