FlinkRelMdUtil.scala 29.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

19
package org.apache.flink.table.planner.plan.utils
20

21
import org.apache.flink.table.data.binary.BinaryRowData
22 23 24 25
import org.apache.flink.table.planner.JDouble
import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, WindowAggregate}
26
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalSortWindowAggregate, BatchPhysicalGroupAggregateBase, BatchPhysicalLocalHashWindowAggregate, BatchPhysicalWindowAggregateBase}
27 28
import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankRange}
import org.apache.flink.table.runtime.operators.sort.BinaryIndexedSortable
29
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer.LENGTH_SIZE_IN_BYTES
30

31 32 33
import com.google.common.collect.ImmutableList
import org.apache.calcite.avatica.util.TimeUnitRange._
import org.apache.calcite.plan.RelOptUtil
34
import org.apache.calcite.rel.core._
35 36
import org.apache.calcite.rel.metadata.{RelMdUtil, RelMetadataQuery}
import org.apache.calcite.rel.{RelNode, SingleRel}
37
import org.apache.calcite.rex._
38 39 40
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.sql.`type`.SqlTypeName.{TIME, TIMESTAMP}
import org.apache.calcite.util.{ImmutableBitSet, NumberUtil}
41

42 43
import java.math.BigDecimal
import java.util
44

45
import scala.collection.JavaConversions._
46
import scala.collection.mutable
47

48
/**
 * FlinkRelMdUtil provides utility methods used by the metadata provider methods.
 */
object FlinkRelMdUtil {

  /**
   * Returns an estimate of the number of rows returned by a SEMI/ANTI [[Join]].
   *
   * The estimate is `rowCount(left) * s` for a SEMI join and `rowCount(left) * (1 - s)` for an
   * ANTI join, where `s` is `RexUtil.getSelectivity(condition)`.
   *
   * @param mq        instance of metadata query
   * @param left      left input of the join
   * @param right     right input of the join; not used by the estimate
   * @param joinType  type of the join; not used by the estimate, `isAnti` selects the formula
   * @param condition join condition
   * @param isAnti    true for an ANTI join, false for a SEMI join
   * @return estimated row count, or null if the row count of `left` is unknown
   */
  def getSemiAntiJoinRowCount(mq: RelMetadataQuery, left: RelNode, right: RelNode,
      joinType: JoinRelType, condition: RexNode, isAnti: Boolean): JDouble = {
    val leftCount = mq.getRowCount(left)
    if (leftCount == null) {
      null
    } else {
      val semiSelectivity = RexUtil.getSelectivity(condition)
      val selectivity = if (isAnti) 1d - semiSelectivity else semiSelectivity
      leftCount * selectivity
    }
  }

  /**
   * Creates a RexNode that stores a selectivity value corresponding to the
   * selectivity of a semi-join/anti-join. This can be added to a filter to simulate the
   * effect of the semi-join/anti-join during costing, but should never appear in a real
   * plan since it has no physical implementation.
   *
   * @param mq  instance of metadata query
   * @param rel the SEMI/ANTI join of interest
   * @return constructed rexNode
   */
  def makeSemiAntiJoinSelectivityRexNode(mq: RelMetadataQuery, rel: Join): RexNode = {
    require(rel.getJoinType == JoinRelType.SEMI || rel.getJoinType == JoinRelType.ANTI)
    val joinInfo = rel.analyzeCondition()
    val rexBuilder = rel.getCluster.getRexBuilder
    makeSemiAntiJoinSelectivityRexNode(
      mq, joinInfo, rel.getLeft, rel.getRight, rel.getJoinType == JoinRelType.ANTI, rexBuilder)
  }

  /**
   * Builds the artificial selectivity call for a SEMI/ANTI join.
   *
   * The SEMI selectivity is the product of the equi-key selectivity (1.0 when there are no
   * equi keys) and the guessed selectivity of the remaining non-equi condition. An ANTI join
   * uses `1 - semiSelectivity`, replaced by 0.1 when that would be exactly 0.0.
   */
  private def makeSemiAntiJoinSelectivityRexNode(
      mq: RelMetadataQuery,
      joinInfo: JoinInfo,
      left: RelNode,
      right: RelNode,
      isAnti: Boolean,
      rexBuilder: RexBuilder): RexNode = {
    val equiSelectivity: JDouble = if (!joinInfo.leftKeys.isEmpty) {
      RelMdUtil.computeSemiJoinSelectivity(mq, left, right, joinInfo.leftKeys, joinInfo.rightKeys)
    } else {
      1D
    }

    val nonEquiSelectivity = RelMdUtil.guessSelectivity(joinInfo.getRemaining(rexBuilder))
    val semiJoinSelectivity = equiSelectivity * nonEquiSelectivity

    val selectivity = if (isAnti) {
      val antiJoinSelectivity = 1.0 - semiJoinSelectivity
      if (antiJoinSelectivity == 0.0) {
        // we don't expect that anti-join's selectivity is 0.0, so choose a default value 0.1
        0.1
      } else {
        antiJoinSelectivity
      }
    } else {
      semiJoinSelectivity
    }

    rexBuilder.makeCall(
      RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC,
      rexBuilder.makeApproxLiteral(new BigDecimal(selectivity)))
  }

  /**
   * Estimates new distinctRowCount of currentNode after it applies a condition.
   * The estimation is based on one assumption:
   * even distribution of all distinct data
   *
   * @param rowCount         rowcount of node.
   * @param distinctRowCount distinct rowcount of node.
   * @param selectivity      selectivity of condition expression.
   * @return new distinctRowCount, never less than 1.0
   */
  def adaptNdvBasedOnSelectivity(
      rowCount: JDouble,
      distinctRowCount: JDouble,
      selectivity: JDouble): JDouble = {
    val ndv = Math.min(distinctRowCount, rowCount)
    Math.max((1 - Math.pow(1 - selectivity, rowCount / ndv)) * ndv, 1.0)
  }

  /**
   * Estimates ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable.
   *
   * the value of `1.0 - math.exp(-0.1 * groupCount)` increases with groupCount
   * from 0.095 until close to 1.0. when groupCount is 1, the formula result is 0.095,
   * when groupCount is 2, the formula result is 0.18,
   * when groupCount is 3, the formula result is 0.25.
   * ...
   *
   * @param groupingLength grouping keys length of aggregate
   * @return the ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable.
   */
  def getAggregationRatioIfNdvUnavailable(groupingLength: Int): JDouble =
    1.0 - math.exp(-0.1 * groupingLength)

  /**
   * Creates a RexNode that stores a selectivity value corresponding to the
   * selectivity of a NamedProperties predicate.
   *
   * @param winAgg  window aggregate node
   * @param predicate a RexNode
   * @return constructed rexNode including non-NamedProperties predicates and
   *         a predicate that stores NamedProperties predicate's selectivity
   */
  def makeNamePropertiesSelectivityRexNode(
      winAgg: WindowAggregate,
      predicate: RexNode): RexNode = {
    val fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(winAgg)
    makeNamePropertiesSelectivityRexNode(winAgg, fullGroupSet, winAgg.getNamedProperties, predicate)
  }

  /**
   * Creates a RexNode that stores a selectivity value corresponding to the
   * selectivity of a NamedProperties predicate.
   *
   * @param globalWinAgg global window aggregate node; must be final (a local window aggregate
   *                     has no NamedProperties)
   * @param predicate a RexNode
   * @return constructed rexNode including non-NamedProperties predicates and
   *         a predicate that stores NamedProperties predicate's selectivity
   */
  def makeNamePropertiesSelectivityRexNode(
      globalWinAgg: BatchPhysicalWindowAggregateBase,
      predicate: RexNode): RexNode = {
    require(globalWinAgg.isFinal, "local window agg does not contain NamedProperties!")
    val fullGrouping = globalWinAgg.grouping ++ globalWinAgg.auxGrouping
    makeNamePropertiesSelectivityRexNode(
      globalWinAgg, fullGrouping, globalWinAgg.namedWindowProperties, predicate)
  }

  /**
   * Creates a RexNode that stores a selectivity value corresponding to the
   * selectivity of a NamedProperties predicate.
   *
   * The NamedProperties are the trailing `namedProperties.size` fields of `winAgg`'s row type.
   * Conjuncts that only reference fields before them are kept as they are. The remaining
   * conjuncts are replaced by a single artificial selectivity call carrying their guessed
   * selectivity.
   *
   * @param winAgg window aggregate node
   * @param fullGrouping full groupSets (not read by this method)
   * @param namedProperties NamedWindowProperty list
   * @param predicate a RexNode
   * @return constructed rexNode including non-NamedProperties predicates and
   *         a predicate that stores NamedProperties predicate's selectivity; `predicate` itself
   *         if it is null, always true, or there are no NamedProperties
   */
  def makeNamePropertiesSelectivityRexNode(
      winAgg: SingleRel,
      fullGrouping: Array[Int],
      namedProperties: Seq[PlannerNamedWindowProperty],
      predicate: RexNode): RexNode = {
    if (predicate == null || predicate.isAlwaysTrue || namedProperties.isEmpty) {
      predicate
    } else {
      val rexBuilder = winAgg.getCluster.getRexBuilder
      val namePropertiesStartIdx = winAgg.getRowType.getFieldCount - namedProperties.size
      // split non-nameProperties predicates and nameProperties predicates
      val pushable = new util.ArrayList[RexNode]
      val notPushable = new util.ArrayList[RexNode]
      RelOptUtil.splitFilters(
        ImmutableBitSet.range(0, namePropertiesStartIdx),
        predicate,
        pushable,
        notPushable)
      if (notPushable.nonEmpty) {
        val pred = RexUtil.composeConjunction(rexBuilder, notPushable, true)
        val selectivity = RelMdUtil.guessSelectivity(pred)
        val fun = rexBuilder.makeCall(
          RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC,
          rexBuilder.makeApproxLiteral(new BigDecimal(selectivity)))
        pushable.add(fun)
      }
      RexUtil.composeConjunction(rexBuilder, pushable, true)
    }
  }

  /**
   * Returns the number of distinct values provided numSelected are selected
   * where there are domainSize distinct values.
   *
   * <p>Current implementation of RelMdUtil#numDistinctVals in Calcite 1.26
   * has precision problem, so we treat small and large inputs differently
   * here and handle large inputs with the old implementation of
   * RelMdUtil#numDistinctVals in Calcite 1.22.
   *
   * <p>This method should be removed once CALCITE-4351 is fixed. See CALCITE-4351
   * and FLINK-19780.
   */
  def numDistinctVals(domainSize: Double, numSelected: Double): Double = {
    val EPS = 1e-9
    if (Math.abs(1 / domainSize) < EPS) {
      // ln(1+x) ~= x for small x
      val dSize = RelMdUtil.capInfinity(domainSize)
      val numSel = RelMdUtil.capInfinity(numSelected)
      val res = if (dSize > 0) (1.0 - Math.exp(-1 * numSel / dSize)) * dSize else 0
      // fix the boundary cases
      Math.max(0, Math.min(res, Math.min(dSize, numSel)))
    } else {
      RelMdUtil.numDistinctVals(domainSize, numSelected)
    }
  }

  /**
   * Estimates outputRowCount of local aggregate.
   *
   * output rowcount of local agg is (1 - pow((1 - 1/x) , n/m)) * m * x, where x is
   * `parallelism`, n is `inputRowCount` and m is `globalAggRowCount`; the result is capped at
   * `inputRowCount`. It is based on two assumptions:
   * 1. even distribution of all distinct data
   * 2. even distribution of all data in each concurrent local agg worker
   *
   * @param parallelism       number of concurrent worker of local aggregate
   * @param inputRowCount     rowcount of input node of aggregate.
   * @param globalAggRowCount rowcount of output of global aggregate.
   * @return outputRowCount of local aggregate.
   */
  def getRowCountOfLocalAgg(
      parallelism: Int,
      inputRowCount: JDouble,
      globalAggRowCount: JDouble): JDouble =
    Math.min((1 - math.pow(1 - 1.0 / parallelism, inputRowCount / globalAggRowCount))
      * globalAggRowCount * parallelism, inputRowCount)

  /**
   * Takes a bitmap representing a set of input references and extracts the
   * ones that reference the group by columns in an aggregate.
   *
   * @param groupKey the original bitmap
   * @param aggRel   the aggregate
   * @return the child keys of the (aux)grouping columns referenced by `groupKey`, and the
   *         aggregate calls referenced by `groupKey` (empty if `groupKey` covers the full
   *         grouping)
   */
  def setAggChildKeys(
      groupKey: ImmutableBitSet,
      aggRel: Aggregate): (ImmutableBitSet, Array[AggregateCall]) = {
    val childKeyBuilder = ImmutableBitSet.builder
    val aggCalls = new mutable.ArrayBuffer[AggregateCall]()
    val groupSet = aggRel.getGroupSet.toArray
    val (auxGroupSet, otherAggCalls) = AggregateUtil.checkAndSplitAggCalls(aggRel)
    val fullGroupSet = groupSet ++ auxGroupSet
    // does not need to take keys in aggregate call into consideration if groupKey contains all
    // groupSet element in aggregate
    val containsAllAggGroupKeys = fullGroupSet.indices.forall(groupKey.get)
    groupKey.foreach { bit =>
      if (bit < fullGroupSet.length) {
        childKeyBuilder.set(fullGroupSet(bit))
      } else if (!containsAllAggGroupKeys) {
        // getIndicatorCount return 0 if auxGroupSet is not empty
        aggCalls += otherAggCalls.get(bit - (fullGroupSet.length + aggRel.getIndicatorCount))
      }
    }
    (childKeyBuilder.build(), aggCalls.toArray)
  }

  /**
   * Takes a bitmap representing a set of input references and extracts the
   * ones that reference the group by columns in an aggregate.
   *
   * @param groupKey the original bitmap
   * @param aggRel   the aggregate; must not be a global agg that merges a local agg
   */
  def setAggChildKeys(
      groupKey: ImmutableBitSet,
      aggRel: BatchPhysicalGroupAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
    require(!aggRel.isFinal || !aggRel.isMerge, "Cannot handle global agg which has local agg!")
    setChildKeysOfAgg(groupKey, aggRel)
  }

  /**
   * Takes a bitmap representing a set of input references and extracts the
   * ones that reference the group by columns in an aggregate.
   *
   * @param groupKey the original bitmap
   * @param aggRel   the aggregate; must not be a global agg that merges a local agg
   */
  def setAggChildKeys(
      groupKey: ImmutableBitSet,
      aggRel: BatchPhysicalWindowAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
    require(!aggRel.isFinal || !aggRel.isMerge, "Cannot handle global agg which has local agg!")
    setChildKeysOfAgg(groupKey, aggRel)
  }

  /**
   * Shared implementation of the batch `setAggChildKeys` overloads. Local window aggregates
   * output `assignTs` between grouping and auxGrouping, so it is part of their full group set.
   */
  private def setChildKeysOfAgg(
      groupKey: ImmutableBitSet,
      agg: SingleRel): (ImmutableBitSet, Array[AggregateCall]) = {
    val (aggCalls, fullGroupSet) = agg match {
      case rel: BatchExecLocalSortWindowAggregate =>
        // grouping + assignTs + auxGrouping
        (rel.getAggCallList,
          rel.grouping ++ Array(rel.inputTimeFieldIndex) ++ rel.auxGrouping)
      case rel: BatchPhysicalLocalHashWindowAggregate =>
        // grouping + assignTs + auxGrouping
        (rel.getAggCallList,
          rel.grouping ++ Array(rel.inputTimeFieldIndex) ++ rel.auxGrouping)
      case rel: BatchPhysicalWindowAggregateBase =>
        (rel.getAggCallList, rel.grouping ++ rel.auxGrouping)
      case rel: BatchPhysicalGroupAggregateBase =>
        (rel.getAggCallList, rel.grouping ++ rel.auxGrouping)
      case _ => throw new IllegalArgumentException(s"Unknown aggregate: ${agg.getRelTypeName}")
    }
    // does not need to take keys in aggregate call into consideration if groupKey contains all
    // groupSet element in aggregate
    val containsAllAggGroupKeys = fullGroupSet.indices.forall(groupKey.get)
    val childKeyBuilder = ImmutableBitSet.builder
    val aggs = new mutable.ArrayBuffer[AggregateCall]()
    groupKey.foreach { bit =>
      if (bit < fullGroupSet.length) {
        childKeyBuilder.set(fullGroupSet(bit))
      } else if (!containsAllAggGroupKeys) {
        aggs += aggCalls.get(bit - fullGroupSet.length)
      }
    }
    (childKeyBuilder.build(), aggs.toArray)
  }

  /**
   * Takes a bitmap representing a set of local window aggregate references.
   *
   * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
   * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
   *
   * Skips `assignTs` when mapping `groupKey` to `childKey`: keys inside the grouping are kept,
   * every other key is shifted by one.
   *
   * @param groupKey the original bitmap
   * @param globalWinAgg the global window aggregate; must merge a local window aggregate
   */
  def setChildKeysOfWinAgg(
      groupKey: ImmutableBitSet,
      globalWinAgg: BatchPhysicalWindowAggregateBase): ImmutableBitSet = {
    require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local window agg!")
    val childKeyBuilder = ImmutableBitSet.builder
    groupKey.toArray.foreach { key =>
      if (key < globalWinAgg.grouping.length) {
        childKeyBuilder.set(key)
      } else {
        // skips `assignTs`
        childKeyBuilder.set(key + 1)
      }
    }
    childKeyBuilder.build()
  }

  /**
   * Split groupKeys on Aggregate/ BatchPhysicalGroupAggregateBase/ BatchPhysicalWindowAggregateBase
   * into keys on aggregate's groupKey and aggregate's aggregateCalls.
   *
   * When the child keys contain the whole groupSet, auxGroupSet keys are removed from them.
   *
   * @param agg      the aggregate
   * @param groupKey the original bitmap
   * @throws IllegalArgumentException if `agg` is not one of the supported aggregate types
   */
  def splitGroupKeysOnAggregate(
      agg: SingleRel,
      groupKey: ImmutableBitSet): (ImmutableBitSet, Array[AggregateCall]) = {

    def removeAuxKey(
        groupKey: ImmutableBitSet,
        groupSet: Array[Int],
        auxGroupSet: Array[Int]): ImmutableBitSet = {
      if (groupKey.contains(ImmutableBitSet.of(groupSet: _*))) {
        // remove auxGroupSet from groupKey if groupKey contain both full-groupSet
        // and (partial-)auxGroupSet
        groupKey.except(ImmutableBitSet.of(auxGroupSet: _*))
      } else {
        groupKey
      }
    }

    agg match {
      case rel: Aggregate =>
        val (auxGroupSet, _) = AggregateUtil.checkAndSplitAggCalls(rel)
        val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
        val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGroupSet.toArray, auxGroupSet)
        (childKeyExcludeAuxKey, aggCalls)
      case rel: BatchPhysicalGroupAggregateBase =>
        // set the bits as they correspond to the child input
        val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
        val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.grouping, rel.auxGrouping)
        (childKeyExcludeAuxKey, aggCalls)
      case rel: BatchPhysicalWindowAggregateBase =>
        val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
        val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.grouping, rel.auxGrouping)
        (childKeyExcludeAuxKey, aggCalls)
      case _ => throw new IllegalArgumentException(s"Unknown aggregate: ${agg.getRelTypeName}.")
    }
  }

  /**
   * Split a predicate on Aggregate into two parts, the first one is pushable part,
   * the second one is rest part.
   *
   * @param agg       Aggregate which to analyze
   * @param predicate Predicate which to analyze
   * @return a tuple, first element is pushable part, second element is rest part.
   *         Note, pushable condition will be converted based on the input field position.
   */
  def splitPredicateOnAggregate(
      agg: Aggregate,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    val fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(agg)
    splitPredicateOnAgg(fullGroupSet, agg, predicate)
  }

  /**
   * Split a predicate on BatchPhysicalGroupAggregateBase into two parts,
   * the first one is pushable part, the second one is rest part.
   *
   * @param agg       Aggregate which to analyze
   * @param predicate Predicate which to analyze
   * @return a tuple, first element is pushable part, second element is rest part.
   *         Note, pushable condition will be converted based on the input field position.
   */
  def splitPredicateOnAggregate(
      agg: BatchPhysicalGroupAggregateBase,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    splitPredicateOnAgg(agg.grouping ++ agg.auxGrouping, agg, predicate)
  }

  /**
   * Split a predicate on BatchPhysicalWindowAggregateBase into two parts,
   * the first one is pushable part, the second one is rest part.
   *
   * @param agg       Aggregate which to analyze
   * @param predicate Predicate which to analyze
   * @return a tuple, first element is pushable part, second element is rest part.
   *         Note, pushable condition will be converted based on the input field position.
   */
  def splitPredicateOnAggregate(
      agg: BatchPhysicalWindowAggregateBase,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    splitPredicateOnAgg(agg.grouping ++ agg.auxGrouping, agg, predicate)
  }

  /**
   * Shifts every [[RexInputRef]] in an expression whose index is at least the length of the
   * full grouping (grouping + auxGrouping) by one (for skips `assignTs`).
   *
   * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
   * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
   *
   * @param predicate a RexNode; null is returned unchanged
   * @param globalWinAgg the global window aggregate; must merge a local window aggregate
   */
  def setChildPredicateOfWinAgg(
      predicate: RexNode,
      globalWinAgg: BatchPhysicalWindowAggregateBase): RexNode = {
    require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local window agg!")
    if (predicate == null) {
      null
    } else {
      // grouping + auxGrouping of the global agg (does NOT include `assignTs`)
      val fullGrouping = globalWinAgg.grouping ++ globalWinAgg.auxGrouping
      // skips `assignTs`
      // NOTE(review): refs in [grouping.length, fullGrouping.length) are not shifted here, while
      // setChildKeysOfWinAgg shifts every key >= grouping.length; confirm which offset is intended
      RexUtil.shift(predicate, fullGrouping.length, 1)
    }
  }

  /**
   * Splits `predicate` into the conjuncts that only reference the first `grouping.length`
   * output fields of `agg` (rewritten onto `agg`'s input fields using `grouping`) and the rest.
   */
  private def splitPredicateOnAgg(
      grouping: Array[Int],
      agg: SingleRel,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    val notPushable = new util.ArrayList[RexNode]
    val pushable = new util.ArrayList[RexNode]
    val numOfGroupKey = grouping.length
    RelOptUtil.splitFilters(
      ImmutableBitSet.range(0, numOfGroupKey),
      predicate,
      pushable,
      notPushable)
    val rexBuilder = agg.getCluster.getRexBuilder
    val childPred = if (pushable.isEmpty) {
      None
    } else {
      // Converts a list of expressions that are based on the output fields of a
      // Aggregate to equivalent expressions on the Aggregate's input fields.
      val aggOutputFields = agg.getRowType.getFieldList
      val aggInputFields = agg.getInput.getRowType.getFieldList
      val adjustments = new Array[Int](aggOutputFields.size)
      grouping.zipWithIndex foreach {
        case (bit, index) => adjustments(index) = bit - index
      }
      val pushableConditions = pushable map {
        pushCondition =>
          pushCondition.accept(
            new RelOptUtil.RexInputConverter(
              rexBuilder,
              aggOutputFields,
              aggInputFields,
              adjustments))
      }
      Option(RexUtil.composeConjunction(rexBuilder, pushableConditions, true))
    }
    val restPred = if (notPushable.isEmpty) {
      None
    } else {
      Option(RexUtil.composeConjunction(rexBuilder, notPushable, true))
    }
    (childPred, restPred)
  }

  /**
   * Estimates the average size in bytes of a row of `rel` in [[BinaryRowData]] format.
   *
   * A fixed-length field costs 8 bytes. A variable-length field costs its average column size
   * plus 8 bytes, or 16 bytes when its average column size is unknown. The width of the null
   * bit set is added on top.
   *
   * @param rel the relational expression whose rows are measured
   * @return estimated average row size in bytes
   */
  def binaryRowAverageSize(rel: RelNode): JDouble = {
    val binaryType = FlinkTypeFactory.toLogicalRowType(rel.getRowType)
    val fieldTypes = binaryType.getChildren
    // TODO reuse FlinkRelMetadataQuery here
    val mq = rel.getCluster.getMetadataQuery
    // Calcite allows the whole list to be null when column sizes are unknown; treat that
    // as "every column size unknown" instead of failing with a NullPointerException.
    val columnSizes: List[JDouble] = Option(mq.getAverageColumnSizes(rel))
      .map(sizes => sizes.toList)
      .getOrElse(List.fill[JDouble](fieldTypes.size)(null))
    val fieldsLength = columnSizes.zip(fieldTypes).map {
      case (_, fieldType) if BinaryRowData.isInFixedLengthPart(fieldType) => 8d
      // find a better way of computing generic type field variable-length
      // right now we use a small value assumption
      case (null, _) => 16d
      // the 8 bytes is used store the length and offset of variable-length part.
      case (columnSize, _) => columnSize + 8d
    }.sum
    fieldsLength + BinaryRowData.calculateBitSetWidthInBytes(columnSizes.size)
  }

  /**
   * Estimates the memory in bytes needed to sort the rows of `inputOfSort`: the record area
   * (average binary row size plus its length prefix, per row) plus the index area
   * (normalized key plus offset, per row).
   *
   * @param mq          instance of metadata query
   * @param inputOfSort input of the sort
   * @return estimated sort memory in bytes, or null if the row count of `inputOfSort` is unknown
   */
  def computeSortMemory(mq: RelMetadataQuery, inputOfSort: RelNode): JDouble = {
    val rowCount = mq.getRowCount(inputOfSort)
    if (rowCount == null) {
      null
    } else {
      //TODO It's hard to make sure that the normalized key's length is accurate in optimized stage.
      // use SortCodeGenerator.MAX_NORMALIZED_KEY_LEN instead of 16
      val normalizedKeyBytes = 16
      val averageRowSize = binaryRowAverageSize(inputOfSort)
      val recordAreaInBytes = rowCount * (averageRowSize + LENGTH_SIZE_IN_BYTES)
      val indexAreaInBytes = rowCount * (normalizedKeyBytes + BinaryIndexedSortable.OFFSET_LEN)
      recordAreaInBytes + indexAreaInBytes
    }
  }

  /**
   * Splits `predicate` on a [[Rank]] into the conjuncts that only reference fields before the
   * rank number column and the conjuncts that do not.
   *
   * @param rank      the rank node
   * @param predicate predicate to split; may be null
   * @return (non-rank part, rank part). If `predicate` is null or always true, or `rank` has no
   *         rank number column, returns `(Option(predicate), None)`.
   */
  def splitPredicateOnRank(
      rank: Rank,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    val rankFunColumnIndex = RankUtil.getRankNumberColumnIndex(rank).getOrElse(-1)
    if (predicate == null || predicate.isAlwaysTrue || rankFunColumnIndex < 0) {
      // Option(...) avoids producing Some(null) for a null predicate
      (Option(predicate), None)
    } else {
      val rankNodes = new util.ArrayList[RexNode]
      val nonRankNodes = new util.ArrayList[RexNode]
      RelOptUtil.splitFilters(
        ImmutableBitSet.range(0, rankFunColumnIndex),
        predicate,
        nonRankNodes,
        rankNodes)
      val rexBuilder = rank.getCluster.getRexBuilder
      val nonRankPred = if (nonRankNodes.isEmpty) {
        None
      } else {
        Option(RexUtil.composeConjunction(rexBuilder, nonRankNodes, true))
      }
      val rankPred = if (rankNodes.isEmpty) {
        None
      } else {
        Option(RexUtil.composeConjunction(rexBuilder, rankNodes, true))
      }
      (nonRankPred, rankPred)
    }
  }

  /**
   * Returns the number of distinct rank values: `rankEnd - rankStart + 1` for a
   * [[ConstantRankRange]], otherwise a default of 100.
   */
  def getRankRangeNdv(rankRange: RankRange): JDouble = rankRange match {
    case r: ConstantRankRange => (r.getRankEnd - r.getRankStart + 1).toDouble
    case _ => 100D // default value now
  }

  /**
   * Returns [[RexInputRef]] index set of projects corresponding to the given column index.
   * The index will be set as -1 if the given column in project is not a [[RexInputRef]].
   */
  def getInputRefIndices(index: Int, expand: Expand): util.Set[Int] = {
    val inputRefs = new util.HashSet[Int]()
    for (project <- expand.projects) {
      project.get(index) match {
        case inputRef: RexInputRef => inputRefs.add(inputRef.getIndex)
        case _ => inputRefs.add(-1)
      }
    }
    inputRefs
  }

  /**
   * Splits a column set between left and right sets. Columns at or beyond `leftCount` go to the
   * right set, re-indexed relative to the right input.
   */
  def splitColumnsIntoLeftAndRight(
      leftCount: Int,
      columns: ImmutableBitSet): (ImmutableBitSet, ImmutableBitSet) = {
    val leftBuilder = ImmutableBitSet.builder
    val rightBuilder = ImmutableBitSet.builder
    columns.foreach {
      bit => if (bit < leftCount) leftBuilder.set(bit) else rightBuilder.set(bit - leftCount)
    }
    (leftBuilder.build, rightBuilder.build)
  }

  /**
   * Computes the cardinality of a particular expression from the projection
   * list.
   *
   * @param mq   metadata query instance
   * @param calc calc RelNode
   * @param expr projection expression
   * @return cardinality, or null if it cannot be estimated
   */
  def cardOfCalcExpr(mq: RelMetadataQuery, calc: Calc, expr: RexNode): JDouble = {
    expr.accept(new CardOfCalcExpr(mq, calc))
  }

  /**
   * Visitor that walks over a scalar expression and computes the
   * cardinality of its result.
   * The code is borrowed from RelMdUtil (CardOfProjExpr); like there, the number of selected
   * values passed to `numDistinctVals` is the row count of the node.
   *
   * @param mq   metadata query instance
   * @param calc calc relnode
   */
  private class CardOfCalcExpr(
      mq: RelMetadataQuery,
      calc: Calc)
    extends RexVisitorImpl[JDouble](true) {
    private val program = calc.getProgram

    private val condition = if (program.getCondition != null) {
      program.expandLocalRef(program.getCondition)
    } else {
      null
    }

    override def visitInputRef(inputRef: RexInputRef): JDouble = {
      val col = ImmutableBitSet.of(inputRef.getIndex)
      val distinctRowCount = mq.getDistinctRowCount(calc.getInput, col, condition)
      distinctVals(distinctRowCount, mq.getRowCount(calc))
    }

    override def visitLiteral(literal: RexLiteral): JDouble = {
      distinctVals(1D, mq.getRowCount(calc))
    }

    override def visitCall(call: RexCall): JDouble = {
      val rowCount = mq.getRowCount(calc)
      val distinctRowCount: JDouble = if (call.isA(SqlKind.MINUS_PREFIX)) {
        cardOfCalcExpr(mq, calc, call.getOperands.get(0))
      } else if (call.isA(ImmutableList.of(SqlKind.PLUS, SqlKind.MINUS))) {
        val card0 = cardOfCalcExpr(mq, calc, call.getOperands.get(0))
        if (card0 == null) {
          null
        } else {
          val card1 = cardOfCalcExpr(mq, calc, call.getOperands.get(1))
          if (card1 == null) {
            null
          } else {
            Math.max(card0, card1)
          }
        }
      } else if (call.isA(ImmutableList.of(SqlKind.TIMES, SqlKind.DIVIDE))) {
        NumberUtil.multiply(
          cardOfCalcExpr(mq, calc, call.getOperands.get(0)),
          cardOfCalcExpr(mq, calc, call.getOperands.get(1)))
      } else if (call.isA(SqlKind.EXTRACT)) {
        val extractUnit = call.getOperands.get(0)
        val timeOperand = call.getOperands.get(1)
        extractUnit match {
          // go https://www.postgresql.org/docs/9.1/static/
          // functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT to get the definitions of timeunits
          case unit: RexLiteral =>
            val unitValue = unit.getValue
            val timeOperandType = timeOperand.getType.getSqlTypeName
            // assume min time is 1970-01-01 00:00:00, max time is 2100-12-31 21:59:59
            unitValue match {
              case YEAR => 130D // [1970, 2100]
              case MONTH => 12D
              case DAY => 31D
              case HOUR => 24D
              case MINUTE => 60D
              case SECOND => timeOperandType match {
                case TIMESTAMP | TIME => 60 * 1000D // [0.000, 59.999]
                case _ => 60D // [0, 59]
              }
              case QUARTER => 4D
              case WEEK => 53D // [1, 53]
              case MILLISECOND => timeOperandType match {
                case TIMESTAMP | TIME => 60 * 1000D // [0.000, 59.999]
                case _ => 60D // [0, 59]
              }
              case MICROSECOND => timeOperandType match {
                case TIMESTAMP | TIME => 60 * 1000D * 1000D // [0.000000, 59.999999]
                case _ => 60D // [0, 59]
              }
              case DOW => 7D // [0, 6]
              case DOY => 366D // [1, 366]
              case EPOCH => timeOperandType match {
                // the number of seconds since 1970-01-01 00:00:00 UTC over ~130 years
                // (milliseconds for TIMESTAMP/TIME); computed in Double to avoid Int overflow
                case TIMESTAMP | TIME => 130D * 365 * 24 * 60 * 60 * 1000
                case _ => 130D * 365 * 24 * 60 * 60
              }
              case DECADE => 13D // The year field divided by 10
              case CENTURY => 2D
              case MILLENNIUM => 2D
              case _ => cardOfCalcExpr(mq, calc, timeOperand)
            }
          case _ => cardOfCalcExpr(mq, calc, timeOperand)
        }
      } else if (call.getOperands.size() == 1) {
        cardOfCalcExpr(mq, calc, call.getOperands.get(0))
      } else {
        if (rowCount != null) rowCount / 10 else null
      }
      distinctVals(distinctRowCount, rowCount)
    }

    /**
     * Applies `FlinkRelMdUtil.numDistinctVals` with `rowCount` as the number of selected rows.
     * Returns null if either argument is unknown (null).
     */
    private def distinctVals(domainSize: JDouble, rowCount: JDouble): JDouble = {
      if (domainSize == null || rowCount == null) {
        null
      } else {
        FlinkRelMdUtil.numDistinctVals(domainSize, rowCount)
      }
    }

  }

}