diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 011e55a54b2c7f0c9c4f3fe78ffd54a41da2ea97..09ec2c5704f532a3d44326afb8e807cb0966f26a 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -399,8 +399,6 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: if len(inps) == 1: return inps[0] - # FIXME: remove this convert_inputs - inps = convert_inputs(*inps, device=device) if device is None: device = get_device(inps) device = as_device(device) @@ -1168,9 +1166,10 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None): bcast_shape.append(shape[axis + 1 :]) target_shape.append(shape[axis + 1 :]) - out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape( - concat(target_shape) - ) + base_shape = astensor1d(base_shape) + bcast_shape = astensor1d(bcast_shape) + target_shape = astensor1d(target_shape) + out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape) return out @@ -1191,9 +1190,10 @@ def _tile_one_dim(inp, rep, axis): if axis + 1 <= max_axis: target_shape.append(shape[axis + 1 :]) - out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape( - concat(target_shape) - ) + base_shape = astensor1d(base_shape) + bcast_shape = astensor1d(bcast_shape) + target_shape = astensor1d(target_shape) + out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape) return out