提交 b3889938 编写于 作者: M Megvii Engine Team

feat(mge/examples): add trace & dump example of cifar10 quantization

GitOrigin-RevId: cfc5e3483a66915e92458d1f502f1dde64ffe564
上级 67859f04
......@@ -332,6 +332,7 @@ class trace:
need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes
links = ()
readers = []
if self._capture_as_const:
for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
......@@ -345,7 +346,6 @@ class trace:
for op, ihandles, ohandles in self._seq:
ivars = []
readers = []
for h in ihandles:
info = self._tinfo[h]
if not hasattr(info, "varnode"):
......@@ -431,11 +431,19 @@ class trace:
if output_names and not isinstance(output_names, collections.Sequence):
output_names = (output_names,)
if output_names and len(output_names) != len(self._output_bindings):
raise ValueError("wrong number of output_names")
raise ValueError(
"wrong number of output_names, should be {} values".format(
len(self._output_bindings)
)
)
if arg_names and not isinstance(arg_names, collections.Sequence):
arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._arg_bindings):
raise ValueError("wrong number of arg_names")
raise ValueError(
"wrong number of arg_names, should be {} values".format(
len(self._arg_bindings)
)
)
output_names = output_names or self._output_names
h2v = {}
......
......@@ -118,8 +118,8 @@ class MinMaxObserver(Observer):
# stop gradient
x = x_orig.detach()
# find max and min
self.min_val = F.minimum(self.min_val, x.min())
self.max_val = F.maximum(self.max_val, x.max())
self.min_val.set_value(F.minimum(self.min_val, x.min()))
self.max_val.set_value(F.maximum(self.max_val, x.max()))
return x_orig
......@@ -144,11 +144,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
# stop gradient
x = x_orig.detach()
# Exponential Moving Average
self.min_val = (
self.min_val.set_value(
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.min()
)
self.max_val = (
self.max_val.set_value(
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.max()
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册