diff --git a/drivers/infiniband/hw/mlx4/cq.c b/drivers/infiniband/hw/mlx4/cq.c
index 6d5f405912dd28c2bb25434614c06c3a52cb19dd..bf4f14a1b4fcbed8f0895fa56c43eeaf8ed68a5f 100644
--- a/drivers/infiniband/hw/mlx4/cq.c
+++ b/drivers/infiniband/hw/mlx4/cq.c
@@ -140,14 +140,18 @@ static int mlx4_ib_get_cq_umem(struct mlx4_ib_dev *dev, struct ib_ucontext *cont
 {
 	int err;
 	int cqe_size = dev->dev->caps.cqe_size;
+	int shift;
+	int n;
 
 	*umem = ib_umem_get(context, buf_addr, cqe * cqe_size,
 			    IB_ACCESS_LOCAL_WRITE, 1);
 	if (IS_ERR(*umem))
 		return PTR_ERR(*umem);
 
-	err = mlx4_mtt_init(dev->dev, ib_umem_page_count(*umem),
-			    (*umem)->page_shift, &buf->mtt);
+	n = ib_umem_page_count(*umem);
+	shift = mlx4_ib_umem_calc_optimal_mtt_size(*umem, 0, &n);
+	err = mlx4_mtt_init(dev->dev, n, shift, &buf->mtt);
+
 	if (err)
 		goto err_buf;
 
diff --git a/drivers/infiniband/hw/mlx4/mlx4_ib.h b/drivers/infiniband/hw/mlx4/mlx4_ib.h
index d09d4e7b27068d7354c2fa42457502ec06ec3478..719dae354066eb3bb43ac94dfd474803a3ee696e 100644
--- a/drivers/infiniband/hw/mlx4/mlx4_ib.h
+++ b/drivers/infiniband/hw/mlx4/mlx4_ib.h
@@ -935,5 +935,7 @@ struct ib_rwq_ind_table
 			      struct ib_rwq_ind_table_init_attr *init_attr,
 			      struct ib_udata *udata);
 int mlx4_ib_destroy_rwq_ind_table(struct ib_rwq_ind_table *wq_ind_table);
+int mlx4_ib_umem_calc_optimal_mtt_size(struct ib_umem *umem, u64 start_va,
+				       int *num_of_mtts);
 
 #endif /* MLX4_IB_H */
diff --git a/drivers/infiniband/hw/mlx4/mr.c b/drivers/infiniband/hw/mlx4/mr.c
index 8f408a02f6994da91ee52ad2600cc3538691685d..313bfb9ccb71a3b42ebacbaa9b98b91daca51a19 100644
--- a/drivers/infiniband/hw/mlx4/mr.c
+++ b/drivers/infiniband/hw/mlx4/mr.c
@@ -254,9 +254,8 @@ int mlx4_ib_umem_write_mtt(struct mlx4_ib_dev *dev, struct mlx4_mtt *mtt,
  * middle already handled as part of mtt shift calculation for both their start
  * & end addresses.
  */
-static int mlx4_ib_umem_calc_optimal_mtt_size(struct ib_umem *umem,
-					      u64 start_va,
-					      int *num_of_mtts)
+int mlx4_ib_umem_calc_optimal_mtt_size(struct ib_umem *umem, u64 start_va,
+				       int *num_of_mtts)
 {
 	u64 block_shift = MLX4_MAX_MTT_SHIFT;
 	u64 min_shift = umem->page_shift;
diff --git a/drivers/infiniband/hw/mlx4/qp.c b/drivers/infiniband/hw/mlx4/qp.c
index 26f3345948e2a58120fe890c103cd87c8bd64a14..f807a6278d44571631df042ed32b5b937fb3a991 100644
--- a/drivers/infiniband/hw/mlx4/qp.c
+++ b/drivers/infiniband/hw/mlx4/qp.c
@@ -1038,6 +1038,8 @@ static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd,
 			struct mlx4_ib_create_wq wq;
 		} ucmd;
 		size_t copy_len;
+		int shift;
+		int n;
 
 		copy_len = (src == MLX4_IB_QP_SRC) ?
 			   sizeof(struct mlx4_ib_create_qp) :
@@ -1100,8 +1102,10 @@ static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd,
 			goto err;
 		}
 
-		err = mlx4_mtt_init(dev->dev, ib_umem_page_count(qp->umem),
-				    qp->umem->page_shift, &qp->mtt);
+		n = ib_umem_page_count(qp->umem);
+		shift = mlx4_ib_umem_calc_optimal_mtt_size(qp->umem, 0, &n);
+		err = mlx4_mtt_init(dev->dev, n, shift, &qp->mtt);
+
 		if (err)
 			goto err_buf;