/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.shardingsphere.sharding.route.engine.validator.dml;

import com.google.common.collect.Lists;
import org.apache.shardingsphere.infra.binder.segment.table.TablesContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import org.apache.shardingsphere.sharding.route.engine.validator.dml.impl.ShardingInsertStatementValidator;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDuplicateKeyColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLInsertStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@RunWith(MockitoJUnitRunner.class)
56
public final class ShardingInsertStatementValidatorTest {
57 58 59
    
    @Mock
    private ShardingRule shardingRule;
L
Liang Zhang 已提交
60
    
61 62
    @Test(expected = ShardingSphereException.class)
    public void assertValidateInsertModifyMultiTables() {
63
        SQLStatementContext<InsertStatement> sqlStatementContext = new InsertStatementContext(
64
                new ShardingSphereSchema(Collections.emptyMap()), Collections.singletonList(1), createInsertStatement());
65 66 67
        Collection<String> shardingTableNames = Lists.newArrayList("order", "order_item");
        when(shardingRule.getShardingLogicTableNames(sqlStatementContext.getTablesContext().getTableNames())).thenReturn(shardingTableNames);
        when(shardingRule.isAllBindingTables(shardingTableNames)).thenReturn(true);
68
        new ShardingInsertStatementValidator().preValidate(shardingRule, sqlStatementContext, Collections.emptyList(), mock(ShardingSphereSchema.class));
69
    }
L
Liang Zhang 已提交
70
    
71 72 73
    @Test
    public void assertValidateOnDuplicateKeyWithoutShardingKey() {
        when(shardingRule.isShardingColumn("id", "user")).thenReturn(false);
74
        SQLStatementContext<InsertStatement> sqlStatementContext = new InsertStatementContext(
75 76
                new ShardingSphereSchema(Collections.emptyMap()), Collections.singletonList(1), createInsertStatement());
        new ShardingInsertStatementValidator().preValidate(shardingRule, sqlStatementContext, Collections.emptyList(), mock(ShardingSphereSchema.class));
77 78
    }
    
79
    @Test(expected = ShardingSphereException.class)
80 81
    public void assertValidateOnDuplicateKeyWithShardingKey() {
        when(shardingRule.isShardingColumn("id", "user")).thenReturn(true);
82
        SQLStatementContext<InsertStatement> sqlStatementContext = new InsertStatementContext(
83 84
                new ShardingSphereSchema(Collections.emptyMap()), Collections.singletonList(1), createInsertStatement());
        new ShardingInsertStatementValidator().preValidate(shardingRule, sqlStatementContext, Collections.emptyList(), mock(ShardingSphereSchema.class));
85 86 87 88 89 90
    }
    
    @Test(expected = ShardingSphereException.class)
    public void assertValidateInsertSelectWithoutKeyGenerateColumn() {
        when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
        when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(false);
91
        SQLStatementContext<InsertStatement> sqlStatementContext = new InsertStatementContext(
92
                new ShardingSphereSchema(Collections.emptyMap()), Collections.singletonList(1), createInsertSelectStatement());
93
        sqlStatementContext.getTablesContext().getTables().addAll(createSingleTablesContext().getTables());
94
        new ShardingInsertStatementValidator().preValidate(shardingRule, sqlStatementContext, Collections.emptyList(), mock(ShardingSphereSchema.class));
95 96 97 98 99 100
    }
    
    @Test
    public void assertValidateInsertSelectWithKeyGenerateColumn() {
        when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
        when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(true);
101
        SQLStatementContext<InsertStatement> sqlStatementContext = new InsertStatementContext(
102
                new ShardingSphereSchema(Collections.emptyMap()), Collections.singletonList(1), createInsertSelectStatement());
103
        sqlStatementContext.getTablesContext().getTables().addAll(createSingleTablesContext().getTables());
104
        new ShardingInsertStatementValidator().preValidate(shardingRule, sqlStatementContext, Collections.emptyList(), mock(ShardingSphereSchema.class));
105 106 107 108 109 110 111 112
    }
    
    @Test(expected = ShardingSphereException.class)
    public void assertValidateInsertSelectWithoutBindingTables() {
        when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
        when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(true);
        TablesContext multiTablesContext = createMultiTablesContext();
        when(shardingRule.isAllBindingTables(multiTablesContext.getTableNames())).thenReturn(false);
113
        SQLStatementContext<InsertStatement> sqlStatementContext = new InsertStatementContext(
114
                new ShardingSphereSchema(Collections.emptyMap()), Collections.singletonList(1), createInsertSelectStatement());
115
        sqlStatementContext.getTablesContext().getTables().addAll(multiTablesContext.getTables());
116
        new ShardingInsertStatementValidator().preValidate(shardingRule, sqlStatementContext, Collections.emptyList(), mock(ShardingSphereSchema.class));
117 118 119 120 121 122 123 124
    }
    
    @Test
    public void assertValidateInsertSelectWithBindingTables() {
        when(shardingRule.findGenerateKeyColumnName("user")).thenReturn(Optional.of("id"));
        when(shardingRule.isGenerateKeyColumn("id", "user")).thenReturn(true);
        TablesContext multiTablesContext = createMultiTablesContext();
        when(shardingRule.isAllBindingTables(multiTablesContext.getTableNames())).thenReturn(true);
125
        SQLStatementContext<InsertStatement> sqlStatementContext = new InsertStatementContext(
126
                new ShardingSphereSchema(Collections.emptyMap()), Collections.singletonList(1), createInsertSelectStatement());
127
        sqlStatementContext.getTablesContext().getTables().addAll(multiTablesContext.getTables());
128
        new ShardingInsertStatementValidator().preValidate(shardingRule, sqlStatementContext, Collections.emptyList(), mock(ShardingSphereSchema.class));
129 130
    }
    
131
    private InsertStatement createInsertStatement() {
132
        MySQLInsertStatement result = new MySQLInsertStatement();
133
        result.setTable(new SimpleTableSegment(0, 0, new IdentifierValue("user")));
134
        ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("id"));
135
        AssignmentSegment assignmentSegment = new AssignmentSegment(0, 0, columnSegment, new ParameterMarkerExpressionSegment(0, 0, 1));
136
        result.setOnDuplicateKeyColumns(new OnDuplicateKeyColumnsSegment(0, 0, Collections.singletonList(assignmentSegment)));
137 138 139
        Collection<ColumnSegment> columns = new LinkedList<>();
        columns.add(columnSegment);
        result.setInsertColumns(new InsertColumnsSegment(0, 0, columns));
140 141 142 143
        return result;
    }

    private InsertStatement createInsertSelectStatement() {
144
        InsertStatement result = createInsertStatement();
145
        SelectStatement selectStatement = new MySQLSelectStatement();
146 147
        selectStatement.setProjections(new ProjectionsSegment(0, 0));
        result.setInsertSelect(new SubquerySegment(0, 0, selectStatement));
148 149
        return result;
    }
150 151 152 153 154 155 156 157 158 159 160 161 162
    
    private TablesContext createSingleTablesContext() {
        List<SimpleTableSegment> result = new LinkedList<>();
        result.add(new SimpleTableSegment(0, 0, new IdentifierValue("user")));
        return new TablesContext(result);
    }
    
    private TablesContext createMultiTablesContext() {
        List<SimpleTableSegment> result = new LinkedList<>();
        result.add(new SimpleTableSegment(0, 0, new IdentifierValue("user")));
        result.add(new SimpleTableSegment(0, 0, new IdentifierValue("order")));
        return new TablesContext(result);
    }
163
}