Commit 9748aebe authored by Megvii Engine Team

refactor(mge): tensor_shape -> symbolic_shape

GitOrigin-RevId: 366dc048bfd7473a6bd148cb5d1ab70235aa43f1
Parent 8acc3acf
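Illustration (not part of the commit): a minimal sketch of the renamed trace-option helpers, assuming a MegEngine build that includes this change. The MEGENGINE_USE_TENSOR_SHAPE environment variable is likewise renamed to MEGENGINE_USE_SYMBOLIC_SHAPE.

import numpy as np

from megengine import tensor
from megengine.core._trace_option import set_symbolic_shape, use_symbolic_shape

x = tensor(np.ones((2, 3), dtype=np.float32))

set_symbolic_shape(False)        # tensor.shape returns a plain Python tuple
assert not use_symbolic_shape()
print(x.shape)                   # (2, 3)

set_symbolic_shape(True)         # tensor.shape now returns a tensor instead of a tuple
assert use_symbolic_shape()
print(x.shape.numpy())           # shape is tensor-valued, e.g. array([2, 3])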
@@ -9,20 +9,20 @@
 import os

-_use_tensor_shape = False
-if os.environ.get("MEGENGINE_USE_TENSOR_SHAPE"):
-    _use_tensor_shape = True
+_use_symbolic_shape = False
+if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
+    _use_symbolic_shape = True


-def use_tensor_shape() -> bool:
+def use_symbolic_shape() -> bool:
     """Returns whether tensor.shape returns a tensor instead of a tuple
     """
-    return _use_tensor_shape
+    return _use_symbolic_shape


-def set_tensor_shape(option: bool):
+def set_symbolic_shape(option: bool):
     """ Sets whether tensor.shape returns a tensor instead of a tuple
     """
-    global _use_tensor_shape
-    _use_tensor_shape = option
+    global _use_symbolic_shape
+    _use_symbolic_shape = option
@@ -10,7 +10,7 @@ from typing import Iterable

 import numpy as np

-from .._trace_option import use_tensor_shape
+from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.special import Const
 from .core import TensorBase, TensorWrapperBase, apply
@@ -58,7 +58,7 @@ def check_bool_index(tensor, tuple_val):
                         )
                     )
                 i = i.reshape(-1)
-                if not use_tensor_shape():
+                if not use_symbolic_shape():
                     cur_shape = (
                         cur_shape[:idx]
                         + (i.shape[0],)
@@ -76,7 +76,7 @@ def check_bool_index(tensor, tuple_val):
                 offset += 1
                 tensor = tensor.reshape(cur_shape)
                 tdim += tot
-            if use_tensor_shape():
+            if use_symbolic_shape():
                 cur_shape = make_shape_tuple(cur_shape)
             new_tuple_val.append(i)
         else:
......
@@ -11,7 +11,7 @@ import collections

 import numpy as np

-from .._trace_option import use_tensor_shape
+from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.builtin import GetVarShape
 from ..ops.special import Const
@@ -342,7 +342,7 @@ class ArrayMethodMixin(abc.ABC):

     def __len__(self):
         shape = self.shape
-        if use_tensor_shape():
+        if use_symbolic_shape():
             shape = shape.numpy()
         if shape:
             return int(shape[0])
@@ -372,7 +372,7 @@ class ArrayMethodMixin(abc.ABC):

     @property
     def size(self):
-        if use_tensor_shape():
+        if use_symbolic_shape():
             return self.shape.prod()
         return np.prod(self.shape).item()
@@ -462,7 +462,7 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):

     @property
     def shape(self):
-        if use_tensor_shape():
+        if use_symbolic_shape():
             return apply(GetVarShape(), self)[0]
         else:
             return self.__wrapped__.shape
......
@@ -19,7 +19,7 @@ import numpy as np

 from ..core._imperative_rt import GraphProfiler
 from ..core._imperative_rt.ops import OprAttr
-from ..core._trace_option import set_tensor_shape
+from ..core._trace_option import set_symbolic_shape
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -121,7 +121,7 @@ class trace:
         sublinear_memory_config: SublinearMemoryConfig = None,
         profiling: bool = False,
         opt_level: int = None,
-        tensor_shape: bool = True,
+        symbolic_shape: bool = True,
     ):
         self.__wrapped__ = function
         self._symbolic = symbolic
@@ -130,7 +130,7 @@
         self._profiling = profiling
         self._profiler = None
         self._graph_opt_level = opt_level
-        self._tensor_shape = tensor_shape
+        self._symbolic_shape = symbolic_shape
         self._reset()
@@ -152,7 +152,7 @@
         self._output_bindings = None
         self._output_names = None
-        set_tensor_shape(self._tensor_shape)
+        set_symbolic_shape(self._symbolic_shape)

     def _new_handle(self):
         handle = len(self._tinfo)
......
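For the trace changes above, only the constructor keyword is renamed (tensor_shape becomes symbolic_shape); the behaviour is unchanged. A hedged usage sketch, where the import path megengine.jit.trace and the doubling function are illustrative assumptions:

import numpy as np

from megengine import tensor
from megengine.jit import trace


@trace(symbolic=True, symbolic_shape=True)  # formerly tensor_shape=True
def double(x):
    return x * 2


out = double(tensor(np.arange(4, dtype=np.float32)))
print(out.numpy())                          # [0. 2. 4. 6.]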
@@ -18,7 +18,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import jit
-from megengine.core._trace_option import set_tensor_shape
+from megengine.core._trace_option import set_symbolic_shape
 from megengine.core.tensor.utils import make_shape_tuple
 from megengine.functional.debug_param import set_conv_execution_strategy
 from megengine.jit import SublinearMemoryConfig
......
@@ -13,7 +13,7 @@ import pytest

 import megengine.core.ops.builtin
 import megengine.core.tensor.raw_tensor
-from megengine.core._trace_option import use_tensor_shape
+from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.ops._internal import all_ops
 from megengine.core.tensor import Tensor
 from megengine.core.tensor.core import apply
@@ -532,7 +532,7 @@ def test_advance_indexing_with_bool():
     np.testing.assert_equal(a, aa.numpy())

     # XXX: trace does not expect empty condtake tensor
-    if not use_tensor_shape():
+    if not use_symbolic_shape():
         a = np.ones((2, 2), dtype=np.int32)
         b = np.array([[False, False], [False, False]])
         aa = Tensor(a)
......
@@ -17,7 +17,7 @@ import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
 from megengine import Parameter, Tensor, is_cuda_available, tensor
-from megengine.core._trace_option import use_tensor_shape
+from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
 from megengine.core.tensor.utils import make_shape_tuple
......
@@ -15,7 +15,7 @@ from utils import opr_test

 import megengine.functional as F
 from megengine import tensor
-from megengine.core._trace_option import use_tensor_shape
+from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.tensor.utils import astensor1d
 from megengine.distributed.helper import get_device_count_by_fork
......
@@ -16,7 +16,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
-from megengine.core._trace_option import use_tensor_shape
+from megengine.core._trace_option import use_symbolic_shape
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

 _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
......
@@ -15,7 +15,7 @@ import pytest
 import megengine.core.tensor.megbrain_graph as G
 import megengine.functional as F
 from megengine import cgtools, tensor
-from megengine.core._trace_option import set_tensor_shape
+from megengine.core._trace_option import set_symbolic_shape
 from megengine.core.ops import builtin as ops
 from megengine.core.tensor.core import apply
 from megengine.core.tensor.raw_tensor import as_raw_tensor
@@ -238,7 +238,7 @@ def test_optimize_for_inference():
 def test_optimize_for_inference_broadcast():
     a = tensor(np.ones(1, dtype=np.float32))

-    @trace(capture_as_const=True, tensor_shape=True)
+    @trace(capture_as_const=True, symbolic_shape=True)
     def f():
         (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
         return b
@@ -248,7 +248,7 @@ def test_optimize_for_inference_broadcast():

 def test_trace_cvt_bool():
-    set_tensor_shape(True)
+    set_symbolic_shape(True)
     x = tensor([0], dtype=np.int32)

     @trace(symbolic=True)
@@ -261,7 +261,7 @@ def test_trace_cvt_bool():

 def test_trace_reshape():
     for symbolic in [False, True]:
-        set_tensor_shape(True)
+        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(2, 10, 10))
         x2 = tensor(np.random.randn(4, 10, 10))
         x3 = tensor(np.random.randn(8, 10, 10))
@@ -344,7 +344,7 @@ def test_raise_on_trace():

 def test_trace_broadcast():
     for symbolic in [False, True]:
-        set_tensor_shape(True)
+        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(3, 1, 1))
         x2 = tensor(np.random.randn(1, 4, 1))
         x3 = tensor(np.random.randn(1, 1, 5))
@@ -382,7 +382,7 @@ def test_trace_nms():

 def test_trace_valid_broadcast():
-    set_tensor_shape(True)
+    set_symbolic_shape(True)
     x1 = tensor(np.random.randn(1, 1))
     x2 = tensor(np.random.randn(1, 2))
     shape = (tensor([2]), tensor([2]))
......