Commit 0d4568d6 authored by Megvii Engine Team

feat(mge): rename broadcast -> broadcast_to

GitOrigin-RevId: 82f46ad2c22d3b0e3e33783107825669487fc3a1
Parent f551d44e
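
A minimal before/after sketch of the renamed API, adapted from the docstring example updated in this commit (array contents are only illustrative):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
    # the old spelling F.broadcast(data, (4, 2, 3)) is removed by this commit
    out = F.broadcast_to(data, (4, 2, 3))
    print(out.numpy().shape)  # (4, 2, 3)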
@@ -100,6 +100,8 @@ class GradManager:
         :param ys: outputs of forward operators, e.g., the loss tensor
         :param dys: derivatives of ys
         """
+        from ..functional import ones_like
+
         global backwarding_grad_manager
         cache = backwarding_grad_manager
         backwarding_grad_manager = self
@@ -113,7 +115,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
+            dys = [ones_like(y) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
......
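
The GradManager hunk above also changes how default output gradients are seeded. A hedged sketch of the equivalence, assuming a small float32 tensor y standing in for a forward output (names are illustrative):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    y = tensor(np.random.rand(2, 3).astype(np.float32))  # stand-in for a forward output
    dy_old = F.broadcast_to(tensor(1.0), y.shape)         # old default: scalar 1 broadcast to y.shape
    dy_new = F.ones_like(y)                               # new default used by this commit
    np.testing.assert_array_equal(dy_old.numpy(), dy_new.numpy())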
@@ -160,7 +160,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
@@ -186,7 +186,7 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
     def make_grad(grad_op, dy):
         grad = (
             TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            .broadcast(TensorWrapper(input_shape))
+            ._broadcast(TensorWrapper(input_shape))
             .__wrapped__
         )
         (dx,) = apply(grad_op, grad, dy, *params)
......
@@ -267,7 +267,7 @@ def setitem(tensor, index, value):
                     value.shape, tmp_result.shape
                 )
             )
-        value = value.broadcast(tmp_result.shape)
+        value = value._broadcast(tmp_result.shape)
     if use_subtensor:
         op = builtin.SetSubtensor(items=items)
     else:
......
@@ -396,7 +396,8 @@ class ArrayMethodMixin(abc.ABC):
     def reshape(self, *args):
         return _reshape(self, _expand_args(args))

-    def broadcast(self, *args):
+    # FIXME: remove this method
+    def _broadcast(self, *args):
         return _broadcast(self, _expand_args(args))

     def transpose(self, *args):
......
@@ -23,7 +23,16 @@ from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
 from .math import argsort, max, sum
-from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
+from .tensor import (
+    broadcast_to,
+    concat,
+    expand_dims,
+    full,
+    ones,
+    reshape,
+    squeeze,
+    zeros,
+)
 from .types import _pair, _pair_nonzero

 __all__ = [
@@ -635,7 +644,7 @@ def batch_norm2d(
     def full_value(value):
         C = inp.shape[1]
         (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast(x, [1, C, 1, 1])
+        return broadcast_to(x, [1, C, 1, 1])

     def expand_or_full(x, value):
         if x is None:
@@ -754,7 +763,7 @@ def sync_batch_norm(
     if is_distributed():
         # reduce all nodes' data to calculate mean and variance
-        reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
+        reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
         stat = concat(
             [reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
         )
@@ -968,10 +977,10 @@ def matmul(
     if dim1 != dim2:
         if dim1 < dim2:
             shape1 = shape2[: dim2 - dim1] + shape1
-            inp1 = inp1.broadcast(*shape1)
+            inp1 = broadcast_to(inp1, shape1)
         else:
             shape2 = shape1[: dim1 - dim2] + shape2
-            inp2 = inp2.broadcast(*shape2)
+            inp2 = broadcast_to(inp2, shape2)
         reshaped_batch_size = 1
         for i in shape1[:-2]:
             reshaped_batch_size *= i
@@ -986,9 +995,9 @@ def matmul(
         shp = shape1[:-1] + shape2[-1:]
     elif dim1 == 3 or dim2 == 3:
         if dim2 < 3:
-            inp2 = inp2.broadcast(*(inp1.shape[:1] + inp2.shape))
+            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
         elif dim1 < 3:
-            inp1 = inp1.broadcast(*(inp2.shape[:1] + inp1.shape))
+            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1205,7 +1214,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     else:
         hscale = 1.0 * ih / oh
         wscale = 1.0 * iw / ow
@@ -1221,7 +1230,7 @@ def interpolate(
             [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
             axis=0,
         ).reshape(1, 3, 3)
-        weight = broadcast(weight, (inp.shape[0], 3, 3))
+        weight = broadcast_to(weight, (inp.shape[0], 3, 3))
     weight = weight.astype("float32")
     ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
......
@@ -19,7 +19,7 @@ from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
-from ..core.tensor.tensor_wrapper import _remove_axis
+from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -33,7 +33,7 @@ from .elemwise import ceil
 __all__ = [
     "arange",
-    "broadcast",
+    "broadcast_to",
     "concat",
     "cond_take",
     "expand_dims",
@@ -104,7 +104,7 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
-    return broadcast(x, shape)
+    return broadcast_to(x, shape)


 def ones(shape, dtype="float32", device=None):
@@ -192,7 +192,7 @@ def identity(inp: Tensor) -> Tensor:
     return output


-def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
+def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
     """
     Broadcasts a tensor to given shape.
@@ -209,7 +209,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
         import megengine.functional as F

         data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
-        out = F.broadcast(data, (4, 2, 3))
+        out = F.broadcast_to(data, (4, 2, 3))
         print(out.numpy())

     Outputs:
@@ -229,7 +229,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
       [3. 4. 5.]]]
     """
-    return inp.broadcast(shape)
+    return _broadcast(inp, shape)


 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
@@ -395,8 +395,7 @@ def _get_idx(index, axis):
                 0, index.shape[i] - 1, index.shape[i], device=index.device,
             )
             arange = (
-                arange.reshape(*shape)
-                .broadcast(index.shape)
+                broadcast_to(arange.reshape(*shape), index.shape)
                 .reshape(-1)
                 .astype(np.int32)
             )
......
@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy
 from ..core.tensor import Tensor
 from ..core.tensor.core import apply
 from .math import topk as _topk
-from .tensor import transpose as _transpose
+from .tensor import broadcast_to, transpose


 def accuracy(
@@ -54,8 +54,8 @@ def accuracy(
     _, pred = _topk(logits, k=max(topk), descending=True)
     accs = []
     for k in topk:
-        correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast(
-            target.shape[0], k
+        correct = pred[:, :k].detach() == broadcast_to(
+            transpose(target, (0, "x")), (target.shape[0], k)
         )
         accs.append(correct.astype(np.float32).sum() / target.shape[0])
     if len(topk) == 1:  # type: ignore[arg-type]
......
@@ -319,7 +319,7 @@ def test_Broadcast():
     x = TensorWrapper(x_np)
     grad = Grad().wrt(x, callback=save_to(x))

-    y = F.broadcast(x, (3, 3, 10))
+    y = F.broadcast_to(x, (3, 3, 10))
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
......
@@ -251,17 +251,17 @@ def test_broadcast():
         {"input": [data1, output1_shape], "output": output1_shape},
         {"input": [data2, output2_shape], "output": output2_shape},
     ]
-    opr_test(cases, F.broadcast, compare_fn=compare_fn)
+    opr_test(cases, F.broadcast_to, compare_fn=compare_fn)

     x = F.ones((2, 1, 3))
     with pytest.raises(ValueError):
-        F.broadcast(x, (2, 3, 4))
+        F.broadcast_to(x, (2, 3, 4))

     with pytest.raises(ValueError):
-        F.broadcast(x, (4, 1, 3))
+        F.broadcast_to(x, (4, 1, 3))

     with pytest.raises(ValueError):
-        F.broadcast(x, (1, 3))
+        F.broadcast_to(x, (1, 3))


 def test_utils_astensor1d():
......
@@ -351,7 +351,7 @@ def test_trace_broadcast():
     @trace(symbolic=symbolic, capture_as_const=True)
     def f(x):
-        y = x.broadcast((3, 4, 5))
+        y = F.broadcast_to(x, (3, 4, 5))
         return y

     f(x1)
......