提交 14dbb414 编写于 作者: Z zhaojun12

Merge branch 'dev' of https://github.com/shardingjdbc/sharding-jdbc into dev

......@@ -34,6 +34,7 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
......@@ -59,6 +60,7 @@ public final class InsertOptimizeEngine implements OptimizeEngine {
List<AndCondition> andConditions = insertStatement.getConditions().getOrCondition().getAndConditions();
List<InsertValue> insertValues = insertStatement.getInsertValues().getInsertValues();
List<ShardingCondition> result = new ArrayList<>(andConditions.size());
Iterator<Number> generatedKeys = null;
int count = 0;
for (AndCondition each : andConditions) {
InsertValue insertValue = insertValues.get(count);
......@@ -67,28 +69,36 @@ public final class InsertOptimizeEngine implements OptimizeEngine {
String logicTableName = insertStatement.getTables().getSingleTableName();
Optional<Column> generateKeyColumn = shardingRule.getGenerateKeyColumn(logicTableName);
String expression;
InsertShardingCondition insertShardingCondition;
if (-1 != insertStatement.getGenerateKeyColumnIndex() || !generateKeyColumn.isPresent()) {
expression = insertValue.getExpression();
insertShardingCondition = new InsertShardingCondition(insertValue.getExpression(), currentParameters);
} else {
if (null == generatedKeys) {
generatedKeys = generatedKey.getGeneratedKeys().iterator();
}
String expression;
Number currentGeneratedKey = generatedKeys.next();
if (0 == parameters.size()) {
expression = insertValue.getExpression().substring(0, insertValue.getExpression().length() - 1) + ", " + generatedKey.getGeneratedKeys().get(count).toString() + ")";
expression = insertValue.getExpression().substring(0, insertValue.getExpression().length() - 1) + ", " + currentGeneratedKey.toString() + ")";
} else {
expression = insertValue.getExpression().substring(0, insertValue.getExpression().length() - 1) + ", ?)";
currentParameters.add(generatedKey.getGeneratedKeys().get(count));
currentParameters.add(currentGeneratedKey);
}
insertShardingCondition = new InsertShardingCondition(expression, currentParameters);
insertShardingCondition.getShardingValues().add(getShardingCondition(generateKeyColumn.get(), currentGeneratedKey));
}
InsertShardingCondition insertShardingCondition = new InsertShardingCondition(expression, currentParameters);
insertShardingCondition.getShardingValues().addAll(getShardingCondition(each));
if (-1 == insertStatement.getGenerateKeyColumnIndex() && generateKeyColumn.isPresent()) {
insertShardingCondition.getShardingValues().add(getShardingCondition(generateKeyColumn.get(), generatedKey.getGeneratedKeys().get(count)));
}
result.add(insertShardingCondition);
count++;
}
return new ShardingConditions(result);
}
private ListShardingValue getShardingCondition(final Column column, final Number value) {
return new ListShardingValue<>(column.getTableName(), column.getName(),
new GeneratedKeyCondition(column, -1, value).getConditionValues(parameters));
}
private Collection<ListShardingValue> getShardingCondition(final AndCondition andCondition) {
Collection<ListShardingValue> result = new LinkedList<>();
for (Condition each : andCondition.getConditions()) {
......@@ -96,9 +106,4 @@ public final class InsertOptimizeEngine implements OptimizeEngine {
}
return result;
}
private ListShardingValue getShardingCondition(final Column column, final Number value) {
return new ListShardingValue<>(column.getTableName(), column.getName(),
new GeneratedKeyCondition(column, -1, value).getConditionValues(parameters));
}
}
......@@ -153,6 +153,7 @@ public final class SQLBuilder {
if (dataNode.getDataSourceName().equals(tableUnit.getDataSourceName()) && dataNode.getTableName().equals(tableUnit.getRoutingTables().iterator().next().getActualTableName())) {
expressions.add(shardingCondition.getInsertValueExpression());
parameters.addAll(shardingCondition.getParameters());
break;
}
}
}
......
......@@ -136,15 +136,15 @@ public final class StandardRoutingEngine implements RoutingEngine {
if (databaseShardingValues.isEmpty()) {
return availableTargetDatabases;
}
Collection<String> result = shardingRule.getDatabaseShardingStrategy(tableRule).doSharding(availableTargetDatabases, databaseShardingValues);
Collection<String> result = new LinkedHashSet<>(shardingRule.getDatabaseShardingStrategy(tableRule).doSharding(availableTargetDatabases, databaseShardingValues));
Preconditions.checkState(!result.isEmpty(), "no database route info");
return result;
}
private Collection<DataNode> routeTables(final TableRule tableRule, final String routedDataSource, final List<ShardingValue> tableShardingValues) {
Collection<String> availableTargetTables = tableRule.getActualTableNames(routedDataSource);
Collection<String> routedTables = tableShardingValues.isEmpty() ? availableTargetTables
: shardingRule.getTableShardingStrategy(tableRule).doSharding(availableTargetTables, tableShardingValues);
Collection<String> routedTables = new LinkedHashSet<>(tableShardingValues.isEmpty() ? availableTargetTables
: shardingRule.getTableShardingStrategy(tableRule).doSharding(availableTargetTables, tableShardingValues));
Preconditions.checkState(!routedTables.isEmpty(), "no table route info");
Collection<DataNode> result = new LinkedList<>();
for (String each : routedTables) {
......
......@@ -104,10 +104,10 @@ public final class InsertOptimizeEngineTest {
assertThat(((InsertShardingCondition) actual.getShardingConditions().get(1)).getInsertValueExpression(), is("(?, ?, ?)"));
assertThat(actual.getShardingConditions().get(0).getShardingValues().size(), is(2));
assertThat(actual.getShardingConditions().get(1).getShardingValues().size(), is(2));
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(0), 10);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(1), 1);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(0), 11);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(1), 2);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(0), 1);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(1), 10);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(0), 2);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(1), 11);
}
@Test
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册