Commit 0d4568d6 authored by Megvii Engine Team

feat(mge): rename broadcast -> broadcast_to

GitOrigin-RevId: 82f46ad2c22d3b0e3e33783107825669487fc3a1
Parent f551d44e
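For orientation, a minimal before/after sketch of the user-facing rename (tensor values are illustrative; the old spellings are shown as comments):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
    # before this commit: y = F.broadcast(x, (4, 2, 3))  or  y = x.broadcast((4, 2, 3))
    y = F.broadcast_to(x, (4, 2, 3))   # shape (4, 2, 3); the Tensor method survives only as the private _broadcast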
......@@ -100,6 +100,8 @@ class GradManager:
:param ys: outputs of forward operators, e.g., the loss tensor
:param dys: derivatives of ys
"""
from ..functional import ones_like
global backwarding_grad_manager
cache = backwarding_grad_manager
backwarding_grad_manager = self
......@@ -113,7 +115,7 @@ class GradManager:
if not isinstance(ys, (tuple, list)):
ys = [ys]
if dys is None:
dys = [tensor(1.0).broadcast(y.shape) for y in ys]
dys = [ones_like(y) for y in ys]
if not isinstance(dys, (tuple, list)):
dys = [dys]
try:
......
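The default `dys` change above swaps the scalar-broadcast construction for `ones_like`. A hedged sketch of the equivalence, written against the public functional API rather than GradManager internals (the dtype of `tensor(1.0)` is assumed to default to float32):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    y = tensor(np.zeros((2, 3), dtype=np.float32))
    old_dy = F.broadcast_to(tensor(1.0), y.shape)   # what tensor(1.0).broadcast(y.shape) used to build
    new_dy = F.ones_like(y)                         # the replacement used as the default gradient seed
    np.testing.assert_array_equal(old_dy.numpy(), new_dy.numpy())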
......@@ -160,7 +160,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
def make_grad(grad_op, dy):
grad = (
TensorWrapper(0, dtype=dy.dtype, device=dy.device)
.broadcast(TensorWrapper(input_shape))
._broadcast(TensorWrapper(input_shape))
.__wrapped__
)
(dx,) = apply(grad_op, grad, dy, *params)
......@@ -186,7 +186,7 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
def make_grad(grad_op, dy):
grad = (
TensorWrapper(0, dtype=dy.dtype, device=dy.device)
.broadcast(TensorWrapper(input_shape))
._broadcast(TensorWrapper(input_shape))
.__wrapped__
)
(dx,) = apply(grad_op, grad, dy, *params)
......
......@@ -267,7 +267,7 @@ def setitem(tensor, index, value):
value.shape, tmp_result.shape
)
)
value = value.broadcast(tmp_result.shape)
value = value._broadcast(tmp_result.shape)
if use_subtensor:
op = builtin.SetSubtensor(items=items)
else:
......
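The setitem path keeps the same semantics and only switches to the private helper: the assigned value is still broadcast to the shape of the indexed region. A user-level sketch under that assumption (values made up for illustration):

    import numpy as np
    from megengine import tensor

    x = tensor(np.zeros((2, 3), dtype=np.float32))
    x[0:1, :] = tensor(np.float32(5.0))   # the scalar value is broadcast to the (1, 3) slice
    print(x.numpy())                      # first row becomes 5.0, second row stays 0.0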
......@@ -396,7 +396,8 @@ class ArrayMethodMixin(abc.ABC):
def reshape(self, *args):
return _reshape(self, _expand_args(args))
def broadcast(self, *args):
# FIXME: remove this method
def _broadcast(self, *args):
return _broadcast(self, _expand_args(args))
def transpose(self, *args):
......
......@@ -23,7 +23,16 @@ from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, max, sum
from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
from .tensor import (
broadcast_to,
concat,
expand_dims,
full,
ones,
reshape,
squeeze,
zeros,
)
from .types import _pair, _pair_nonzero
__all__ = [
......@@ -635,7 +644,7 @@ def batch_norm2d(
def full_value(value):
C = inp.shape[1]
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
return broadcast(x, [1, C, 1, 1])
return broadcast_to(x, [1, C, 1, 1])
def expand_or_full(x, value):
if x is None:
......@@ -754,7 +763,7 @@ def sync_batch_norm(
if is_distributed():
# reduce all nodes' data to calculate mean and variance
reduce_size = broadcast(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
reduce_size = broadcast_to(Tensor(reduce_size, dtype=_dtype), [1] * _ndim)
stat = concat(
[reduce_size.astype(_dtype), channel_x1s, channel_x2s], axis=1
)
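In the two batch-norm hunks above, the per-channel constants and the `reduce_size` statistic are now built through the functional spelling. A minimal sketch of the shape this produces, assuming an NCHW input with C = 4:

    import megengine.functional as F
    from megengine import tensor

    C = 4
    x = tensor(1.0, dtype="float32")
    per_channel = F.broadcast_to(x, (1, C, 1, 1))   # mirrors broadcast_to(x, [1, C, 1, 1]) in full_value
    print(per_channel.shape)                        # (1, 4, 1, 1)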
......@@ -968,10 +977,10 @@ def matmul(
if dim1 != dim2:
if dim1 < dim2:
shape1 = shape2[: dim2 - dim1] + shape1
inp1 = inp1.broadcast(*shape1)
inp1 = broadcast_to(inp1, shape1)
else:
shape2 = shape1[: dim1 - dim2] + shape2
inp2 = inp2.broadcast(*shape2)
inp2 = broadcast_to(inp2, shape2)
reshaped_batch_size = 1
for i in shape1[:-2]:
reshaped_batch_size *= i
......@@ -986,9 +995,9 @@ def matmul(
shp = shape1[:-1] + shape2[-1:]
elif dim1 == 3 or dim2 == 3:
if dim2 < 3:
inp2 = inp2.broadcast(*(inp1.shape[:1] + inp2.shape))
inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
elif dim1 < 3:
inp1 = inp1.broadcast(*(inp2.shape[:1] + inp1.shape))
inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
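The matmul hunks keep the existing batch-dimension alignment and only switch to the functional call. A hedged sketch of the rule for the dim2 < 3 branch (shapes chosen arbitrarily):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    a = tensor(np.ones((4, 2, 3), dtype=np.float32))   # 3-d operand
    b = tensor(np.ones((3, 5), dtype=np.float32))      # 2-d operand, dim2 < 3
    # matmul prepends a's batch dim to b, i.e. broadcast_to(b, a.shape[:1] + b.shape):
    b_batched = F.broadcast_to(b, (4, 3, 5))
    out = F.matmul(a, b)                               # shape (4, 2, 5)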
......@@ -1205,7 +1214,7 @@ def interpolate(
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
axis=0,
).reshape(1, 3, 3)
weight = broadcast(weight, (inp.shape[0], 3, 3))
weight = broadcast_to(weight, (inp.shape[0], 3, 3))
else:
hscale = 1.0 * ih / oh
wscale = 1.0 * iw / ow
......@@ -1221,7 +1230,7 @@ def interpolate(
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
axis=0,
).reshape(1, 3, 3)
weight = broadcast(weight, (inp.shape[0], 3, 3))
weight = broadcast_to(weight, (inp.shape[0], 3, 3))
weight = weight.astype("float32")
ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
......
......@@ -19,7 +19,7 @@ from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _remove_axis
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
from ..core.tensor.utils import (
astensor1d,
convert_inputs,
......@@ -33,7 +33,7 @@ from .elemwise import ceil
__all__ = [
"arange",
"broadcast",
"broadcast_to",
"concat",
"cond_take",
"expand_dims",
......@@ -104,7 +104,7 @@ def full(shape, value, dtype="float32", device=None):
(x,) = Const(value, dtype=dtype, device=device)(
Tensor(value, dtype=dtype, device=device)
)
return broadcast(x, shape)
return broadcast_to(x, shape)
def ones(shape, dtype="float32", device=None):
......@@ -192,7 +192,7 @@ def identity(inp: Tensor) -> Tensor:
return output
def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
"""
Broadcasts a tensor to the given shape.
......@@ -209,7 +209,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
import megengine.functional as F
data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.broadcast(data, (4, 2, 3))
out = F.broadcast_to(data, (4, 2, 3))
print(out.numpy())
Outputs:
......@@ -229,7 +229,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
[3. 4. 5.]]]
"""
return inp.broadcast(shape)
return _broadcast(inp, shape)
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
......@@ -395,8 +395,7 @@ def _get_idx(index, axis):
0, index.shape[i] - 1, index.shape[i], device=index.device,
)
arange = (
arange.reshape(*shape)
.broadcast(index.shape)
broadcast_to(arange.reshape(*shape), index.shape)
.reshape(-1)
.astype(np.int32)
)
......
......@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy
from ..core.tensor import Tensor
from ..core.tensor.core import apply
from .math import topk as _topk
from .tensor import transpose as _transpose
from .tensor import broadcast_to, transpose
def accuracy(
......@@ -54,8 +54,8 @@ def accuracy(
_, pred = _topk(logits, k=max(topk), descending=True)
accs = []
for k in topk:
correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast(
target.shape[0], k
correct = pred[:, :k].detach() == broadcast_to(
transpose(target, (0, "x")), (target.shape[0], k)
)
accs.append(correct.astype(np.float32).sum() / target.shape[0])
if len(topk) == 1: # type: ignore[arg-type]
......
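The accuracy change above rewrites the same comparison with the module-level helpers. A small illustration with dummy predictions and k = 2 (data made up for the example):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    pred = tensor(np.array([[1, 0], [2, 1]], dtype=np.int32))   # top-2 class ids, shape (N, k)
    target = tensor(np.array([1, 1], dtype=np.int32))           # ground truth, shape (N,)
    # expand target to (N, k) so each of the k columns is compared against it
    target_2d = F.broadcast_to(F.transpose(target, (0, "x")), (target.shape[0], 2))
    correct = pred == target_2d                                 # boolean mask, shape (N, k)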
......@@ -319,7 +319,7 @@ def test_Broadcast():
x = TensorWrapper(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.broadcast(x, (3, 3, 10))
y = F.broadcast_to(x, (3, 3, 10))
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
......
......@@ -251,17 +251,17 @@ def test_broadcast():
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
]
opr_test(cases, F.broadcast, compare_fn=compare_fn)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
x = F.ones((2, 1, 3))
with pytest.raises(ValueError):
F.broadcast(x, (2, 3, 4))
F.broadcast_to(x, (2, 3, 4))
with pytest.raises(ValueError):
F.broadcast(x, (4, 1, 3))
F.broadcast_to(x, (4, 1, 3))
with pytest.raises(ValueError):
F.broadcast(x, (1, 3))
F.broadcast_to(x, (1, 3))
def test_utils_astensor1d():
......
......@@ -351,7 +351,7 @@ def test_trace_broadcast():
@trace(symbolic=symbolic, capture_as_const=True)
def f(x):
y = x.broadcast((3, 4, 5))
y = F.broadcast_to(x, (3, 4, 5))
return y
f(x1)
......