diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index 1a3dbd68066a72384589ac24579e0540b5484a6e..9fd200bf0344d58d6a2705d768afffc7ce92dcc2 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -327,12 +327,17 @@ class ToTensor(BaseTransform): import paddle.vision.transforms as T import paddle.vision.transforms.functional as F - fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) + fake_img = Image.fromarray((np.random.rand(4, 5, 3) * 255.).astype(np.uint8)) transform = T.ToTensor() tensor = transform(fake_img) - + + print(tensor.shape) + # [3, 4, 5] + + print(tensor.dtype) + # paddle.float32 """ def __init__(self, data_format='CHW', keys=None):