diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index db949a2df8e2492f4f4fa5290f2669729fa1f281..ecd1a5339ef54c85268de908267f8aa1028b5259 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -222,11 +222,14 @@ class Tensor: else: self._undetermined_tensor.requires_grad = requires_grad - def size(self): - return self.shape + def size(self, idx=None): + if idx is None: + return self.shape + else: + return self.shape[idx] - def dim(self, idx): - return self.shape[idx] + def dim(self): + return self.ndim def ndimension(self): return self.ndim diff --git a/oneflow/python/nn/init.py b/oneflow/python/nn/init.py index 67564da889850875f2d47061a87f14c8764a1897..738b2649c7afc515f716c6b42f11d09a0febaa87 100644 --- a/oneflow/python/nn/init.py +++ b/oneflow/python/nn/init.py @@ -79,8 +79,8 @@ def _calculate_fan_in_and_fan_out(tensor): "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" ) - num_input_fmaps = tensor.dim(1) - num_output_fmaps = tensor.dim(0) + num_input_fmaps = tensor.size(1) + num_output_fmaps = tensor.size(0) receptive_field_size = 1 if tensor.ndimension() > 2: # math.prod is not always available, accumulate the product manually