提交 298c94f2 编写于 作者: T terrymanu

for #2567, move add gen key to InsertValue's logic to SQLRewriteEngine

上级 038232e9
......@@ -21,7 +21,6 @@ import com.google.common.base.Optional;
import org.apache.shardingsphere.core.exception.ShardingException;
import org.apache.shardingsphere.core.metadata.table.TableMetas;
import org.apache.shardingsphere.core.optimize.api.segment.InsertValue;
import org.apache.shardingsphere.core.optimize.sharding.constant.ShardingDerivedColumnType;
import org.apache.shardingsphere.core.optimize.sharding.engnie.ShardingOptimizeEngine;
import org.apache.shardingsphere.core.optimize.sharding.segment.condition.ShardingCondition;
import org.apache.shardingsphere.core.optimize.sharding.segment.condition.engine.InsertClauseShardingConditionEngine;
......@@ -34,7 +33,6 @@ import org.apache.shardingsphere.core.parse.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.rule.ShardingRule;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
......@@ -54,10 +52,8 @@ public final class ShardingInsertOptimizeEngine implements ShardingOptimizeEngin
List<String> columnNames = sqlStatement.useDefaultColumns() ? tableMetas.getAllColumnNames(tableName) : sqlStatement.getColumnNames();
Optional<GeneratedKey> generatedKey = GeneratedKey.getGenerateKey(shardingRule, parameters, sqlStatement, columnNames);
boolean isGeneratedValue = generatedKey.isPresent() && generatedKey.get().isGenerated();
Iterator<Comparable<?>> generatedValues = null;
if (isGeneratedValue) {
columnNames.remove(generatedKey.get().getColumnName());
generatedValues = generatedKey.get().getGeneratedValues().iterator();
}
List<String> allColumnNames = getAllColumnNames(columnNames, generatedKey.orNull(), shardingRule.getEncryptRule().getAssistedQueryAndPlainColumns(tableName));
List<ShardingCondition> shardingConditions = new InsertClauseShardingConditionEngine(shardingRule).createShardingConditions(sqlStatement, parameters, allColumnNames, generatedKey.orNull());
......@@ -68,9 +64,6 @@ public final class ShardingInsertOptimizeEngine implements ShardingOptimizeEngin
for (Collection<ExpressionSegment> each : sqlStatement.getAllValueExpressions()) {
InsertValue insertValue = new InsertValue(each, derivedColumnsCount, parameters, parametersOffset);
result.getInsertValues().add(insertValue);
if (isGeneratedValue) {
insertValue.appendValue(generatedValues.next(), ShardingDerivedColumnType.KEY_GEN);
}
parametersOffset += insertValue.getParametersCount();
}
return result;
......
......@@ -48,7 +48,7 @@ public final class GeneratedKey {
private final boolean generated;
private final List<Comparable<?>> generatedValues = new LinkedList<>();
private final LinkedList<Comparable<?>> generatedValues = new LinkedList<>();
/**
* Get generate key.
......
......@@ -168,40 +168,35 @@ public final class ShardingInsertOptimizeEngineTest {
public void assertInsertValuesWithPlaceholderWithGeneratedKey() {
ShardingInsertOptimizedStatement actual = new ShardingInsertOptimizeEngine().optimize(
shardingRule, mock(TableMetas.class), "", insertValuesParameters, insertValuesStatementWithPlaceholder);
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(3));
assertThat(actual.getInsertValues().get(1).getParameters().size(), is(3));
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(1).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(0).getParameters().get(0), CoreMatchers.<Object>is(10));
assertThat(actual.getInsertValues().get(0).getParameters().get(1), CoreMatchers.<Object>is("init"));
assertThat(actual.getInsertValues().get(0).getParameters().get(2), CoreMatchers.<Object>is(1));
assertThat(actual.getInsertValues().get(1).getParameters().get(0), CoreMatchers.<Object>is(11));
assertThat(actual.getInsertValues().get(1).getParameters().get(1), CoreMatchers.<Object>is("init"));
assertThat(actual.getInsertValues().get(1).getParameters().get(2), CoreMatchers.<Object>is(1));
}
@Test
public void assertInsertValuesWithPlaceholderWithGeneratedKeyWithEncrypt() {
ShardingInsertOptimizedStatement actual = new ShardingInsertOptimizeEngine().optimize(
shardingRule, mock(TableMetas.class), "", insertValuesParameters, insertValuesStatementWithPlaceholderWithEncrypt);
assertThat(actual.getInsertValues().get(1).getParameters().size(), is(3));
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(1).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(0).getParameters().get(0), CoreMatchers.<Object>is(10));
assertThat(actual.getInsertValues().get(0).getParameters().get(1), CoreMatchers.<Object>is("init"));
assertThat(actual.getInsertValues().get(0).getParameters().get(2), CoreMatchers.<Object>is(1));
assertThat(actual.getInsertValues().get(1).getParameters().get(0), CoreMatchers.<Object>is(11));
assertThat(actual.getInsertValues().get(1).getParameters().get(1), CoreMatchers.<Object>is("init"));
assertThat(actual.getInsertValues().get(1).getParameters().get(2), CoreMatchers.<Object>is(1));
}
@Test
public void assertInsertValuesWithPlaceholderWithoutGeneratedKey() {
ShardingInsertOptimizedStatement actual = new ShardingInsertOptimizeEngine().optimize(shardingRule, mock(TableMetas.class), "", insertValuesParameters, insertValuesStatementWithPlaceholder);
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(3));
assertThat(actual.getInsertValues().get(1).getParameters().size(), is(3));
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(1).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(0).getParameters().get(0), CoreMatchers.<Object>is(10));
assertThat(actual.getInsertValues().get(0).getParameters().get(1), CoreMatchers.<Object>is("init"));
assertThat(actual.getInsertValues().get(0).getParameters().get(2), CoreMatchers.<Object>is(1));
assertThat(actual.getInsertValues().get(1).getParameters().get(0), CoreMatchers.<Object>is(11));
assertThat(actual.getInsertValues().get(1).getParameters().get(1), CoreMatchers.<Object>is("init"));
assertThat(actual.getInsertValues().get(1).getParameters().get(2), CoreMatchers.<Object>is(1));
}
@Test
......@@ -233,10 +228,9 @@ public final class ShardingInsertOptimizeEngineTest {
insertSetStatementWithPlaceholder.setSetAssignment(new SetAssignmentsSegment(0, 0, Arrays.asList(assignmentSegment1, assignmentSegment2)));
ShardingInsertOptimizedStatement actual = new ShardingInsertOptimizeEngine().optimize(
shardingRule, mock(TableMetas.class), "", insertSetParameters, insertSetStatementWithPlaceholder);
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(3));
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(0).getParameters().get(0), CoreMatchers.<Object>is(12));
assertThat(actual.getInsertValues().get(0).getParameters().get(1), CoreMatchers.<Object>is("a"));
assertThat(actual.getInsertValues().get(0).getParameters().get(2), CoreMatchers.<Object>is(1));
}
@Test
......@@ -246,10 +240,9 @@ public final class ShardingInsertOptimizeEngineTest {
insertSetStatementWithPlaceholderWithQueryEncrypt.setSetAssignment(new SetAssignmentsSegment(0, 0, Arrays.asList(assignmentSegment1, assignmentSegment2)));
ShardingInsertOptimizedStatement actual = new ShardingInsertOptimizeEngine().optimize(
shardingRule, mock(TableMetas.class), "", insertSetParameters, insertSetStatementWithPlaceholderWithQueryEncrypt);
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(3));
assertThat(actual.getInsertValues().get(0).getParameters().size(), is(2));
assertThat(actual.getInsertValues().get(0).getParameters().get(0), CoreMatchers.<Object>is(12));
assertThat(actual.getInsertValues().get(0).getParameters().get(1), CoreMatchers.<Object>is("a"));
assertThat(actual.getInsertValues().get(0).getParameters().get(2), CoreMatchers.<Object>is(1));
}
@Test
......
......@@ -24,6 +24,9 @@ import org.apache.shardingsphere.core.optimize.api.statement.InsertOptimizedStat
import org.apache.shardingsphere.core.optimize.api.statement.OptimizedStatement;
import org.apache.shardingsphere.core.optimize.encrypt.constant.EncryptDerivedColumnType;
import org.apache.shardingsphere.core.optimize.encrypt.statement.EncryptOptimizedStatement;
import org.apache.shardingsphere.core.optimize.sharding.constant.ShardingDerivedColumnType;
import org.apache.shardingsphere.core.optimize.sharding.segment.insert.GeneratedKey;
import org.apache.shardingsphere.core.optimize.sharding.statement.dml.ShardingInsertOptimizedStatement;
import org.apache.shardingsphere.core.rewrite.builder.BaseParameterBuilder;
import org.apache.shardingsphere.core.rewrite.builder.InsertParameterBuilder;
import org.apache.shardingsphere.core.rewrite.builder.ParameterBuilder;
......@@ -44,6 +47,7 @@ import org.apache.shardingsphere.spi.encrypt.ShardingEncryptor;
import org.apache.shardingsphere.spi.encrypt.ShardingQueryAssistedEncryptor;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
......@@ -70,6 +74,7 @@ public final class SQLRewriteEngine {
final SQLRouteResult sqlRouteResult, final String sql, final List<Object> parameters, final boolean isSingleRoute, final boolean isQueryWithCipherColumn) {
baseRule = shardingRule;
optimizedStatement = sqlRouteResult.getShardingStatement();
processGeneratedKey();
encryptOptimizedStatement(shardingRule.getEncryptRule());
parameterBuilder = createParameterBuilder(parameters, sqlRouteResult);
sqlTokens = createSQLTokens(isSingleRoute, isQueryWithCipherColumn);
......@@ -93,6 +98,20 @@ public final class SQLRewriteEngine {
sqlBuilder = new SQLBuilder(sql, sqlTokens);
}
private void processGeneratedKey() {
if (optimizedStatement instanceof ShardingInsertOptimizedStatement) {
Optional<GeneratedKey> generatedKey = ((ShardingInsertOptimizedStatement) optimizedStatement).getGeneratedKey();
boolean isGeneratedValue = generatedKey.isPresent() && generatedKey.get().isGenerated();
if (!isGeneratedValue) {
return;
}
Iterator<Comparable<?>> generatedValues = generatedKey.get().getGeneratedValues().descendingIterator();
for (InsertValue each : ((ShardingInsertOptimizedStatement) optimizedStatement).getInsertValues()) {
each.appendValue(generatedValues.next(), ShardingDerivedColumnType.KEY_GEN);
}
}
}
private void encryptOptimizedStatement(final EncryptRule encryptRule) {
if (optimizedStatement instanceof InsertOptimizedStatement && !encryptRule.getEncryptTableNames().isEmpty()) {
encryptInsertOptimizedStatement(encryptRule, (InsertOptimizedStatement) optimizedStatement);
......
......@@ -226,15 +226,12 @@ public final class StandardRoutingEngine implements RoutingEngine {
private boolean match(final ShardingInsertOptimizedStatement insertOptimizedStatement, final InsertValue insertValue, final ShardingCondition shardingCondition) {
for (RouteValue each : shardingCondition.getRouteValues()) {
Object value = insertValue.getValue(getColumnIndex(insertOptimizedStatement.getColumnNames(), each));
Object value = insertOptimizedStatement.getGeneratedKey().isPresent() && insertOptimizedStatement.getGeneratedKey().get().isGenerated()
? insertOptimizedStatement.getGeneratedKey().get().getGeneratedValues().getLast() : insertValue.getValue(insertOptimizedStatement.getColumnNames().indexOf(each.getColumnName()));
if (!value.equals(((ListRouteValue) each).getValues().iterator().next())) {
return false;
}
}
return true;
}
private int getColumnIndex(final List<String> columnNames, final RouteValue routeValue) {
return columnNames.contains(routeValue.getColumnName()) ? columnNames.indexOf(routeValue.getColumnName()) : columnNames.size();
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册