GroupByMemoryResultSetMerger.java 6.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/*
 * Copyright 1999-2015 dangdang.com.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * </p>
 */

18
package com.dangdang.ddframe.rdb.sharding.merger.groupby;
19

20
import com.dangdang.ddframe.rdb.sharding.merger.common.AbstractMemoryResultSetMerger;
21 22 23 24
import com.dangdang.ddframe.rdb.sharding.merger.common.MemoryResultSetRow;
import com.dangdang.ddframe.rdb.sharding.merger.groupby.aggregation.AggregationUnit;
import com.dangdang.ddframe.rdb.sharding.merger.groupby.aggregation.AggregationUnitFactory;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.selectitem.AggregationSelectItem;
T
terrymanu 已提交
25
import com.dangdang.ddframe.rdb.sharding.parsing.parser.statement.dql.select.SelectStatement;
26 27 28
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
29 30 31 32 33 34 35 36 37

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
38
import java.util.Map.Entry;
39 40 41 42 43 44

/**
 * 内存分组归并结果集接口.
 *
 * @author zhangliang
 */
45
public final class GroupByMemoryResultSetMerger extends AbstractMemoryResultSetMerger {
46
    
47
    private final SelectStatement selectStatement;
48
    
49
    private final Iterator<MemoryResultSetRow> memoryResultSetRows;
50
    
51
    public GroupByMemoryResultSetMerger(
T
terrymanu 已提交
52
            final Map<String, Integer> labelAndIndexMap, final List<ResultSet> resultSets, final SelectStatement selectStatement) throws SQLException {
53
        super(labelAndIndexMap);
54
        this.selectStatement = selectStatement;
55
        memoryResultSetRows = init(resultSets);
56 57
    }
    
58 59 60
    private Iterator<MemoryResultSetRow> init(final List<ResultSet> resultSets) throws SQLException {
        Map<GroupByValue, MemoryResultSetRow> dataMap = new HashMap<>(1024);
        Map<GroupByValue, Map<AggregationSelectItem, AggregationUnit>> aggregationMap = new HashMap<>(1024);
61 62
        for (ResultSet each : resultSets) {
            while (each.next()) {
63
                GroupByValue groupByValue = new GroupByValue(each, selectStatement.getGroupByItems());
64 65
                initForFirstGroupByValue(each, groupByValue, dataMap, aggregationMap);
                aggregate(each, groupByValue, aggregationMap);
66 67
            }
        }
68 69 70 71 72 73 74 75 76 77 78 79
        setAggregationValueToMemoryRow(dataMap, aggregationMap);
        List<MemoryResultSetRow> result = getMemoryResultSetRows(dataMap);
        if (!result.isEmpty()) {
            setCurrentResultSetRow(result.get(0));
        }
        return result.iterator();
    }
    
    private void initForFirstGroupByValue(final ResultSet resultSet, final GroupByValue groupByValue, final Map<GroupByValue, MemoryResultSetRow> dataMap, 
                                          final Map<GroupByValue, Map<AggregationSelectItem, AggregationUnit>> aggregationMap) throws SQLException {
        if (!dataMap.containsKey(groupByValue)) {
            dataMap.put(groupByValue, new MemoryResultSetRow(resultSet));
80
        }
81 82 83 84 85 86
        if (!aggregationMap.containsKey(groupByValue)) {
            Map<AggregationSelectItem, AggregationUnit> map = Maps.toMap(selectStatement.getAggregationSelectItems(), new Function<AggregationSelectItem, AggregationUnit>() {
                
                @Override
                public AggregationUnit apply(final AggregationSelectItem input) {
                    return AggregationUnitFactory.create(input.getType());
87
                }
88 89 90 91 92 93 94 95 96 97 98 99 100
            });
            aggregationMap.put(groupByValue, map);
        }
    }
    
    private void aggregate(final ResultSet resultSet, final GroupByValue groupByValue, final Map<GroupByValue, Map<AggregationSelectItem, AggregationUnit>> aggregationMap) throws SQLException {
        for (AggregationSelectItem each : selectStatement.getAggregationSelectItems()) {
            List<Comparable<?>> values = new ArrayList<>(2);
            if (each.getDerivedAggregationSelectItems().isEmpty()) {
                values.add(getAggregationValue(resultSet, each));
            } else {
                for (AggregationSelectItem derived : each.getDerivedAggregationSelectItems()) {
                    values.add(getAggregationValue(resultSet, derived));
101 102
                }
            }
103
            aggregationMap.get(groupByValue).get(each).merge(values);
104
        }
105 106
    }
    
107 108 109 110 111 112
    private Comparable<?> getAggregationValue(final ResultSet resultSet, final AggregationSelectItem aggregationSelectItem) throws SQLException {
        Object result = resultSet.getObject(aggregationSelectItem.getIndex());
        Preconditions.checkState(null == result || result instanceof Comparable, "Aggregation value must implements Comparable");
        return (Comparable<?>) result;
    }
    
113 114 115 116 117 118 119 120 121 122
    private void setAggregationValueToMemoryRow(final Map<GroupByValue, MemoryResultSetRow> dataMap, final Map<GroupByValue, Map<AggregationSelectItem, AggregationUnit>> aggregationMap) {
        for (Entry<GroupByValue, MemoryResultSetRow> entry : dataMap.entrySet()) {
            for (AggregationSelectItem each : selectStatement.getAggregationSelectItems()) {
                entry.getValue().setCell(each.getIndex(), aggregationMap.get(entry.getKey()).get(each).getResult());
            }
        }
    }
    
    private List<MemoryResultSetRow> getMemoryResultSetRows(final Map<GroupByValue, MemoryResultSetRow> dataMap) {
        List<MemoryResultSetRow> result = new ArrayList<>(dataMap.values());
T
terrymanu 已提交
123
        Collections.sort(result, new GroupByRowComparator(selectStatement));
124 125 126
        return result;
    }
    
127 128
    @Override
    public boolean next() throws SQLException {
129 130
        if (memoryResultSetRows.hasNext()) {
            setCurrentResultSetRow(memoryResultSetRows.next());
131 132 133 134 135
            return true;
        }
        return false;
    }
}