From 69673f14f736e56da0e266af3adeebb8867dcf54 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 24 Mar 2022 14:02:59 +0800 Subject: [PATCH] fix(imperative): remove convert_inputs from concat GitOrigin-RevId: 1511cb4b43766ada3e08ff763f69a3ea4411cc4e --- imperative/python/megengine/functional/tensor.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 011e55a54..09ec2c570 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -399,8 +399,6 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: if len(inps) == 1: return inps[0] - # FIXME: remove this convert_inputs - inps = convert_inputs(*inps, device=device) if device is None: device = get_device(inps) device = as_device(device) @@ -1168,9 +1166,10 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None): bcast_shape.append(shape[axis + 1 :]) target_shape.append(shape[axis + 1 :]) - out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape( - concat(target_shape) - ) + base_shape = astensor1d(base_shape) + bcast_shape = astensor1d(bcast_shape) + target_shape = astensor1d(target_shape) + out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape) return out @@ -1191,9 +1190,10 @@ def _tile_one_dim(inp, rep, axis): if axis + 1 <= max_axis: target_shape.append(shape[axis + 1 :]) - out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape( - concat(target_shape) - ) + base_shape = astensor1d(base_shape) + bcast_shape = astensor1d(bcast_shape) + target_shape = astensor1d(target_shape) + out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape) return out -- GitLab