提交 69673f14 编写于 作者: M Megvii Engine Team

fix(imperative): remove convert_inputs from concat

GitOrigin-RevId: 1511cb4b43766ada3e08ff763f69a3ea4411cc4e
上级 4d5faa3f
......@@ -399,8 +399,6 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
if len(inps) == 1:
return inps[0]
# FIXME: remove this convert_inputs
inps = convert_inputs(*inps, device=device)
if device is None:
device = get_device(inps)
device = as_device(device)
......@@ -1168,9 +1166,10 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
bcast_shape.append(shape[axis + 1 :])
target_shape.append(shape[axis + 1 :])
out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
concat(target_shape)
)
base_shape = astensor1d(base_shape)
bcast_shape = astensor1d(bcast_shape)
target_shape = astensor1d(target_shape)
out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
return out
......@@ -1191,9 +1190,10 @@ def _tile_one_dim(inp, rep, axis):
if axis + 1 <= max_axis:
target_shape.append(shape[axis + 1 :])
out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
concat(target_shape)
)
base_shape = astensor1d(base_shape)
bcast_shape = astensor1d(bcast_shape)
target_shape = astensor1d(target_shape)
out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册