diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index d7772cb387cd826066c4854101393b7ef4616f6d..20151e7228ba7c0f0962d491a73094154b0177a3 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -4133,7 +4133,14 @@ def tensordot(a, b, axes, name=None): prod_axes = int(np.prod([shape_a[i] for i in axes])) perm = list(axes) + free if flipped else free + list(axes) new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes] - reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape) + if (perm != np.arange(len(shape_a))).any(): + a_trans = array_ops.transpose(a, perm) + else: + a_trans = a + if a_trans.get_shape().as_list() != new_shape: + reshaped_a = array_ops.reshape(a_trans, new_shape) + else: + reshaped_a = a_trans return reshaped_a, free_dims, free_dims else: if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)): @@ -4209,7 +4216,12 @@ def tensordot(a, b, axes, name=None): b, b_axes, True) ab_matmul = matmul(a_reshape, b_reshape) if isinstance(a_free_dims, list) and isinstance(b_free_dims, list): - return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name) + if (ab_matmul.get_shape().is_fully_defined() and + ab_matmul.get_shape().as_list() == a_free_dims + b_free_dims): + return ab_matmul + else: + return array_ops.reshape( + ab_matmul, a_free_dims + b_free_dims, name=name) else: a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32) b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32)