提交 b3889938 编写于 作者: M Megvii Engine Team

feat(mge/examples): add trace & dump example of cifar10 quantization

GitOrigin-RevId: cfc5e3483a66915e92458d1f502f1dde64ffe564
上级 67859f04
...@@ -332,6 +332,7 @@ class trace: ...@@ -332,6 +332,7 @@ class trace:
need_reset_nodes = self._need_reset_nodes = [] need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes # links enforce ordering of I/O nodes
links = () links = ()
readers = []
if self._capture_as_const: if self._capture_as_const:
for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()): for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
...@@ -345,7 +346,6 @@ class trace: ...@@ -345,7 +346,6 @@ class trace:
for op, ihandles, ohandles in self._seq: for op, ihandles, ohandles in self._seq:
ivars = [] ivars = []
readers = []
for h in ihandles: for h in ihandles:
info = self._tinfo[h] info = self._tinfo[h]
if not hasattr(info, "varnode"): if not hasattr(info, "varnode"):
...@@ -431,11 +431,19 @@ class trace: ...@@ -431,11 +431,19 @@ class trace:
if output_names and not isinstance(output_names, collections.Sequence): if output_names and not isinstance(output_names, collections.Sequence):
output_names = (output_names,) output_names = (output_names,)
if output_names and len(output_names) != len(self._output_bindings): if output_names and len(output_names) != len(self._output_bindings):
raise ValueError("wrong number of output_names") raise ValueError(
"wrong number of output_names, should be {} values".format(
len(self._output_bindings)
)
)
if arg_names and not isinstance(arg_names, collections.Sequence): if arg_names and not isinstance(arg_names, collections.Sequence):
arg_names = (arg_names,) arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._arg_bindings): if arg_names and len(arg_names) != len(self._arg_bindings):
raise ValueError("wrong number of arg_names") raise ValueError(
"wrong number of arg_names, should be {} values".format(
len(self._arg_bindings)
)
)
output_names = output_names or self._output_names output_names = output_names or self._output_names
h2v = {} h2v = {}
......
...@@ -118,8 +118,8 @@ class MinMaxObserver(Observer): ...@@ -118,8 +118,8 @@ class MinMaxObserver(Observer):
# stop gradient # stop gradient
x = x_orig.detach() x = x_orig.detach()
# find max and min # find max and min
self.min_val = F.minimum(self.min_val, x.min()) self.min_val.set_value(F.minimum(self.min_val, x.min()))
self.max_val = F.maximum(self.max_val, x.max()) self.max_val.set_value(F.maximum(self.max_val, x.max()))
return x_orig return x_orig
...@@ -144,11 +144,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ...@@ -144,11 +144,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
# stop gradient # stop gradient
x = x_orig.detach() x = x_orig.detach()
# Exponential Moving Average # Exponential Moving Average
self.min_val = ( self.min_val.set_value(
self.min_val * self.runtime_momentum self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.min() + (1 - self.runtime_momentum) * x.min()
) )
self.max_val = ( self.max_val.set_value(
self.max_val * self.runtime_momentum self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.max() + (1 - self.runtime_momentum) * x.max()
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册