diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index a2b509a7644d039263cb35fced6399dcec724681..9bbc83866ab245b2559eba41bf73250485cd4a04 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -410,7 +410,7 @@ class ArrayMethodMixin(abc.ABC): def sum(self, axis=None, keepdims: bool = False): r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. If ``axis`` is a list of axises, reduce over all of them. - + If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`). Same for prod/mean/max/min. diff --git a/imperative/python/megengine/functional/__init__.py b/imperative/python/megengine/functional/__init__.py index 6679366860f07416fe614a7d289c753635bece1b..56f26e6aa9abc0fa6678058a6383a421e951d448 100644 --- a/imperative/python/megengine/functional/__init__.py +++ b/imperative/python/megengine/functional/__init__.py @@ -8,7 +8,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # pylint: disable=redefined-builtin from .elemwise import * -from .graph import add_update from .loss import * from .math import * from .nn import * diff --git a/imperative/python/megengine/functional/graph.py b/imperative/python/megengine/functional/graph.py deleted file mode 100644 index 6af627eecfe8a387430a12c38dec87a24c85bc48..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/functional/graph.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import collections -from typing import Iterable, Optional, Union - -from ..tensor import Tensor - - -def add_update( - dest: Tensor, - delta: Tensor, - *, - alpha: Union[Tensor, float, int] = 1.0, - beta: Union[Tensor, float, int] = 1.0, - bias: Union[Tensor, float, int] = 0.0 -): - r"""Modify ``dest`` inplace as follows: - - .. math:: - dest = alpha * dest + beta * delta + bias - - :param dest: input data that will be inplace modified. - :param delta: update value that will be added to ``dest``. - :param alpha: weight ratio of ``dest``. Default: 1.0 - :param beta: weight ratio of ``delta``. Default: 1.0 - :param bias: bias value appended to the result. Default: 0.0 - """ - if beta is not None and beta != 1.0: - delta = delta * beta - if bias is not None and bias != 0.0: - delta = delta + bias - if alpha is not None and alpha != 1.0: - dest *= alpha - dest += delta - return dest diff --git a/imperative/python/megengine/module/qat/conv_bn.py b/imperative/python/megengine/module/qat/conv_bn.py index baa0d769ca2034e80f77a1f890d52bdbac13ea48..bb7414d90e63902fda9b02487400a0f335192b0a 100644 --- a/imperative/python/megengine/module/qat/conv_bn.py +++ b/imperative/python/megengine/module/qat/conv_bn.py @@ -5,7 +5,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ...functional import add_update, ones, relu, sqrt, sum, zeros +from ...functional import ones, relu, sqrt, sum, zeros from ...quantization.utils import fake_quant_bias from .. import conv_bn as Float from .module import QATModule @@ -76,18 +76,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) ) exponential_average_factor = 1 - self.bn.momentum - add_update( - self.bn.running_mean, - delta=bn_mean, - alpha=1 - exponential_average_factor, - beta=exponential_average_factor, - ) - add_update( - self.bn.running_var, - delta=bn_var, - alpha=1 - exponential_average_factor, - beta=exponential_average_factor, - ) + self.bn.running_mean *= self.bn.momentum + self.bn.running_mean += exponential_average_factor * bn_mean + self.bn.running_var *= self.bn.momentum + self.bn.running_var += exponential_average_factor * bn_var def calc_conv_bn_qat(self, inp, approx=True): if self.training and not approx: diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index f50e1bbfa2c700333fe85a1fbc83832cce8d4006..774a7cae93c9cfbaf19832ae0f8ba2788c3ef33b 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -127,7 +127,7 @@ class TQT(_FakeQuantize): # when disable, TQT will do normal forward, initialize scale weight tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) tmp_scale = F.log(tmp_scale / 127) / math.log(2) - F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) + self.scale[...] = tmp_scale return inp def get_qparams(self): diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 3dfce3fadbdead1fcea60870733fddf7b170f661..b18fbc1ee2d965c67c04e8b585301dc9dac5357b 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -290,41 +290,6 @@ def test_one_hot(): onehot_high_dimension() -def test_add_update(): - shape = (2, 3) - v = np.random.random(shape).astype(np.float32) - b = Tensor(v) - - u = F.add_update(b, 1) - np.testing.assert_allclose(u.numpy(), v + 1, atol=1e-6) - u = F.add_update(b, 1) - np.testing.assert_allclose(u.numpy(), v + 2, atol=1e-6) - - x = np.ones((2, 2), dtype=np.float32) - y = x * 0.5 - dest = tensor(x) - delta = tensor(y) - r = F.add_update(dest, delta, alpha=0.9, beta=0.1, bias=0.1) - np.testing.assert_allclose(r.numpy(), x * 0.9 + y * 0.1 + 0.1, atol=1e-6) - - -def test_add_update_params(): - b = np.random.random((2, 3)).astype(np.float32) - y = Tensor(b) - - # @jit.trace - def f(x): - return F.add_update(y, x) - - f(np.zeros((2, 3)).astype(np.float32)) - - z = Tensor(np.zeros((2, 3)).astype(np.float32)) - F.add_update(y, z, beta=0.1) - - res = f(np.ones((2, 3)).astype(np.float32)) - np.testing.assert_allclose(res.numpy(), b + 1) - - def test_binary_cross_entropy(): data1_shape = (2, 2) label1_shape = (2, 2)