提交 4996340f 编写于 作者: T terrymanu

remove value for SQLPlaceholderExpr

上级 fedbd9e6
......@@ -116,7 +116,7 @@ public class SQLParser extends AbstractParser {
private SQLExpr getExpression(final String literals) {
if (equalAny(Symbol.QUESTION)) {
parametersIndex++;
return new SQLPlaceholderExpr(parametersIndex - 1, parameters.get(parametersIndex - 1));
return new SQLPlaceholderExpr(parametersIndex - 1);
}
if (equalAny(Literals.CHARS)) {
return new SQLTextExpr(literals);
......
......@@ -28,8 +28,8 @@ import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
......@@ -71,19 +71,19 @@ public final class ConditionContext {
* @param parameters 参数列表
*/
public void setNewConditionValue(final List<Object> parameters) {
for (Condition each : conditions.values()) {
if (each.getValueIndices().isEmpty()) {
continue;
}
for (int i = 0; i < each.getValueIndices().size(); i++) {
Object value = parameters.get(each.getValueIndices().get(i));
if (value instanceof Comparable<?>) {
each.getValues().set(i, (Comparable<?>) value);
} else {
each.getValues().set(i, "");
}
}
}
// for (Condition each : conditions.values()) {
// if (each.valueIndices.isEmpty()) {
// continue;
// }
// for (int i = 0; i < each.valueIndices.size(); i++) {
// Object value = parameters.get(each.valueIndices.get(i));
// if (value instanceof Comparable<?>) {
// each.values.set(i, (Comparable<?>) value);
// } else {
// each.values.set(i, "");
// }
// }
// }
}
/**
......@@ -100,9 +100,9 @@ public final class ConditionContext {
private final ShardingOperator operator;
private final List<Comparable<?>> values = new ArrayList<>();
private final List<Comparable<?>> values = new LinkedList<>();
private final List<Integer> valueIndices = new ArrayList<>();
private final List<Integer> valueIndices = new LinkedList<>();
public Condition(final ShardingColumnContext shardingColumnContext, final SQLExpr sqlExpr) {
this(shardingColumnContext, ShardingOperator.EQUAL);
......@@ -124,7 +124,6 @@ public final class ConditionContext {
private void initSQLExpr(final SQLExpr sqlExpr) {
if (sqlExpr instanceof SQLPlaceholderExpr) {
values.add((Comparable) ((SQLPlaceholderExpr) sqlExpr).getValue());
valueIndices.add(((SQLPlaceholderExpr) sqlExpr).getIndex());
} else if (sqlExpr instanceof SQLTextExpr) {
values.add(((SQLTextExpr) sqlExpr).getText());
......@@ -132,5 +131,22 @@ public final class ConditionContext {
values.add((Comparable) ((SQLNumberExpr) sqlExpr).getNumber());
}
}
/**
* 获取分片值.
*
* @param parameters 参数列表
* @return 分片值
*/
public List<Comparable<?>> getValues(final List<Object> parameters) {
List<Comparable<?>> result = new LinkedList<>(values);
for (int each : valueIndices) {
Object parameter = parameters.get(each);
if (parameter instanceof Comparable<?>) {
result.add((Comparable<?>) parameter);
}
}
return result;
}
}
}
......@@ -73,7 +73,7 @@ public final class MySQLInsertParser extends AbstractInsertParser {
} else if (getSqlParser().equalAny(DefaultKeyword.NULL)) {
sqlExpr = new SQLIgnoreExpr();
} else if (getSqlParser().equalAny(Symbol.QUESTION)) {
sqlExpr = new SQLPlaceholderExpr(getSqlParser().getParametersIndex(), getSqlParser().getParameters().get(getSqlParser().getParametersIndex()));
sqlExpr = new SQLPlaceholderExpr(getSqlParser().getParametersIndex());
getSqlParser().setParametersIndex(getSqlParser().getParametersIndex() + 1);
} else {
throw new UnsupportedOperationException("");
......
......@@ -30,6 +30,4 @@ import lombok.RequiredArgsConstructor;
public final class SQLPlaceholderExpr implements SQLExpr {
private final int index;
private final Object value;
}
......@@ -168,7 +168,7 @@ public abstract class AbstractInsertParser implements SQLStatementParser {
} else {
itemsToken.getItems().add("?");
parameters.add(autoIncrementedValue);
sqlExprs.add(new SQLPlaceholderExpr(parameters.size() - 1, autoIncrementedValue));
sqlExprs.add(new SQLPlaceholderExpr(parameters.size() - 1));
}
sqlContext.getGeneratedKeyContext().getColumns().add(each.getColumnName());
sqlContext.getGeneratedKeyContext().putValue(each.getColumnName(), autoIncrementedValue);
......
......@@ -125,7 +125,7 @@ public final class SQLRouteEngine {
SQLRouteResult routeSQL(final SQLContext sqlContext, final List<Object> parameters) {
Context context = MetricsContext.start("Route SQL");
SQLRouteResult result = new SQLRouteResult(sqlContext);
RoutingResult routingResult = routeSQL(sqlContext.getConditionContext(), sqlContext);
RoutingResult routingResult = routeSQL(sqlContext.getConditionContext(), sqlContext, parameters);
result.getExecutionUnits().addAll(routingResult.getSQLExecutionUnits(sqlContext.getSqlBuilder()));
amendSQLAccordingToRouteResult(sqlContext, parameters, result);
MetricsContext.stop(context);
......@@ -136,7 +136,7 @@ public final class SQLRouteEngine {
return result;
}
private RoutingResult routeSQL(final ConditionContext conditionContext, final SQLContext sqlContext) {
private RoutingResult routeSQL(final ConditionContext conditionContext, final SQLContext sqlContext, final List<Object> parameters) {
if (HintManagerHolder.isDatabaseShardingOnly()) {
return new DatabaseRouter(shardingRule.getDataSourceRule(), shardingRule.getDatabaseShardingStrategy(), sqlContext.getType()).route();
}
......@@ -148,13 +148,13 @@ public final class SQLRouteEngine {
}
}));
if (1 == logicTables.size()) {
return new SingleTableRouter(shardingRule, logicTables.iterator().next(), conditionContext, sqlContext.getType()).route();
return new SingleTableRouter(shardingRule, parameters, logicTables.iterator().next(), conditionContext, sqlContext.getType()).route();
}
if (shardingRule.isAllBindingTables(logicTables)) {
return new BindingTablesRouter(shardingRule, logicTables, conditionContext, sqlContext.getType()).route();
return new BindingTablesRouter(shardingRule, parameters, logicTables, conditionContext, sqlContext.getType()).route();
}
// TODO 可配置是否执行笛卡尔积
return new MixedTablesRouter(shardingRule, logicTables, conditionContext, sqlContext.getType()).route();
return new MixedTablesRouter(shardingRule, parameters, logicTables, conditionContext, sqlContext.getType()).route();
}
private void amendSQLAccordingToRouteResult(final SQLContext sqlContext, final List<Object> parameters, final SQLRouteResult sqlRouteResult) {
......
......@@ -27,6 +27,7 @@ import com.google.common.base.Preconditions;
import lombok.extern.slf4j.Slf4j;
import java.util.Collection;
import java.util.List;
/**
* Binding库表路由类.
......@@ -38,6 +39,8 @@ public class BindingTablesRouter {
private final ShardingRule shardingRule;
private final List<Object> parameters;
private final Collection<String> logicTables;
private final ConditionContext conditionContext;
......@@ -46,8 +49,9 @@ public class BindingTablesRouter {
private final SQLType sqlType;
public BindingTablesRouter(final ShardingRule shardingRule, final Collection<String> logicTables, final ConditionContext conditionContext, final SQLType sqlType) {
public BindingTablesRouter(final ShardingRule shardingRule, final List<Object> parameters, final Collection<String> logicTables, final ConditionContext conditionContext, final SQLType sqlType) {
this.shardingRule = shardingRule;
this.parameters = parameters;
this.logicTables = logicTables;
this.conditionContext = conditionContext;
this.sqlType = sqlType;
......@@ -65,7 +69,7 @@ public class BindingTablesRouter {
BindingRoutingResult result = null;
for (final String each : logicTables) {
if (null == result) {
result = new BindingRoutingResult(new SingleTableRouter(shardingRule, each, conditionContext, sqlType).route());
result = new BindingRoutingResult(new SingleTableRouter(shardingRule, parameters, each, conditionContext, sqlType).route());
} else {
result.bind(bindingTableRule, each);
}
......
......@@ -29,6 +29,7 @@ import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* 混合多库表路由类.
......@@ -42,6 +43,8 @@ public class MixedTablesRouter {
private final ShardingRule shardingRule;
private final List<Object> parameters;
private final Collection<String> logicTables;
private final ConditionContext conditionContext;
......@@ -59,11 +62,11 @@ public class MixedTablesRouter {
Collection<String> remainingTables = new ArrayList<>(logicTables);
Collection<SingleRoutingResult> result = new ArrayList<>(logicTables.size());
if (1 < bindingTables.size()) {
result.add(new BindingTablesRouter(shardingRule, bindingTables, conditionContext, sqlType).route());
result.add(new BindingTablesRouter(shardingRule, parameters, bindingTables, conditionContext, sqlType).route());
remainingTables.removeAll(bindingTables);
}
for (String each : remainingTables) {
SingleRoutingResult routingResult = new SingleTableRouter(shardingRule, each, conditionContext, sqlType).route();
SingleRoutingResult routingResult = new SingleTableRouter(shardingRule, parameters, each, conditionContext, sqlType).route();
if (null != routingResult) {
result.add(routingResult);
}
......
......@@ -38,10 +38,11 @@ public class SingleRouterUtil {
* 将条件对象转换为分片值对象.
*
* @param condition 条件对象
* @param parameters 参数列表
* @return 分片值对象
*/
public static ShardingValue<?> convertConditionToShardingValue(final ConditionContext.Condition condition) {
List<Comparable<?>> conditionValues = condition.getValues();
public static ShardingValue<?> convertConditionToShardingValue(final ConditionContext.Condition condition, final List<Object> parameters) {
List<Comparable<?>> conditionValues = condition.getValues(parameters);
switch (condition.getOperator()) {
case EQUAL:
case IN:
......
......@@ -53,6 +53,8 @@ public final class SingleTableRouter {
private final ShardingRule shardingRule;
private final List<Object> parameters;
private final String logicTable;
private final ConditionContext conditionContext;
......@@ -61,8 +63,9 @@ public final class SingleTableRouter {
private final SQLType sqlType;
public SingleTableRouter(final ShardingRule shardingRule, final String logicTable, final ConditionContext conditionContext, final SQLType sqlType) {
public SingleTableRouter(final ShardingRule shardingRule, final List<Object> parameters, final String logicTable, final ConditionContext conditionContext, final SQLType sqlType) {
this.shardingRule = shardingRule;
this.parameters = parameters;
this.logicTable = logicTable;
this.conditionContext = conditionContext;
this.sqlType = sqlType;
......@@ -158,7 +161,7 @@ public final class SingleTableRouter {
for (String each : shardingColumns) {
Optional<ConditionContext.Condition> condition = conditionContext.find(logicTable, each);
if (condition.isPresent()) {
result.add(SingleRouterUtil.convertConditionToShardingValue(condition.get()));
result.add(SingleRouterUtil.convertConditionToShardingValue(condition.get(), parameters));
}
}
return result;
......
......@@ -122,8 +122,7 @@ public abstract class AbstractDBUnitTest {
}
}
protected void assertDataSet(final String expectedDataSetFile, final Connection connection, final String actualTableName, final String sql)
throws SQLException, DatabaseUnitException {
protected void assertDataSet(final String expectedDataSetFile, final Connection connection, final String actualTableName, final String sql) throws SQLException, DatabaseUnitException {
try (Connection conn = connection) {
ITable actualTable = getConnection(conn).createQueryTable(actualTableName, sql);
IDataSet expectedDataSet = new FlatXmlDataSetBuilder().build(new InputStreamReader(AbstractDBUnitTest.class.getClassLoader().getResourceAsStream(expectedDataSetFile)));
......
......@@ -33,6 +33,7 @@ import java.util.List;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public final class DeleteStatementParserTest extends AbstractStatementParserTest {
......@@ -54,11 +55,30 @@ public final class DeleteStatementParserTest extends AbstractStatementParserTest
"DELETE FROM TABLE_XXX xxx WHERE field4<10 AND TABLE_XXX.field1=1 AND field5>10 AND xxx.field2 IN (1,3) AND field6<=10 AND field3 BETWEEN 5 AND 20 AND field7>=10",
shardingRule, parameters);
DeleteSQLContext sqlContext = (DeleteSQLContext) statementParser.parseStatement();
assertDeleteStatement(sqlContext);
assertDeleteStatementWithoutParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is(
"DELETE FROM [Token(TABLE_XXX)] xxx WHERE field4<10 AND [Token(TABLE_XXX)].field1=1 AND field5>10 AND xxx.field2 IN (1,3) AND field6<=10 AND field3 BETWEEN 5 AND 20 AND field7>=10"));
}
private void assertDeleteStatementWithoutParameter(final DeleteSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
assertThat(sqlContext.getTables().get(0).getAlias().get(), is("xxx"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
assertThat(condition1.getValues().size(), is(1));
assertThat(condition1.getValues().get(0), is((Comparable) 1));
ConditionContext.Condition condition2 = sqlContext.getConditionContext().find("TABLE_XXX", "field2").get();
assertThat(condition2.getOperator(), is(ShardingOperator.IN));
assertThat(condition2.getValues().size(), is(2));
assertThat(condition2.getValues().get(0), is((Comparable) 1));
assertThat(condition2.getValues().get(1), is((Comparable) 3));
ConditionContext.Condition condition3 = sqlContext.getConditionContext().find("TABLE_XXX", "field3").get();
assertThat(condition3.getOperator(), is(ShardingOperator.BETWEEN));
assertThat(condition3.getValues().size(), is(2));
assertThat(condition3.getValues().get(0), is((Comparable) 5));
assertThat(condition3.getValues().get(1), is((Comparable) 20));
}
@Test
public void parseWithParameter() throws SQLException {
ShardingRule shardingRule = createShardingRule();
......@@ -66,28 +86,31 @@ public final class DeleteStatementParserTest extends AbstractStatementParserTest
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL,
"DELETE FROM TABLE_XXX xxx WHERE field4<? AND field1=? AND field5>? AND field2 IN (?,?) AND field6<=? AND field3 BETWEEN ? AND ? AND field7>=?", shardingRule, parameters);
DeleteSQLContext sqlContext = (DeleteSQLContext) statementParser.parseStatement();
assertDeleteStatement(sqlContext);
assertDeleteStatementWithParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is(
"DELETE FROM [Token(TABLE_XXX)] xxx WHERE field4<? AND field1=? AND field5>? AND field2 IN (?,?) AND field6<=? AND field3 BETWEEN ? AND ? AND field7>=?"));
}
private void assertDeleteStatement(final DeleteSQLContext sqlContext) {
private void assertDeleteStatementWithParameter(final DeleteSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
assertThat(sqlContext.getTables().get(0).getAlias().get(), is("xxx"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
assertThat(condition1.getValues().size(), is(1));
assertThat(condition1.getValues().get(0), is((Comparable) 1));
assertTrue(condition1.getValues().isEmpty());
assertThat(condition1.getValueIndices().size(), is(1));
assertThat(condition1.getValueIndices().get(0), is(1));
ConditionContext.Condition condition2 = sqlContext.getConditionContext().find("TABLE_XXX", "field2").get();
assertThat(condition2.getOperator(), is(ShardingOperator.IN));
assertThat(condition2.getValues().size(), is(2));
assertThat(condition2.getValues().get(0), is((Comparable) 1));
assertThat(condition2.getValues().get(1), is((Comparable) 3));
assertTrue(condition2.getValues().isEmpty());
assertThat(condition2.getValueIndices().size(), is(2));
assertThat(condition2.getValueIndices().get(0), is(3));
assertThat(condition2.getValueIndices().get(1), is(4));
ConditionContext.Condition condition3 = sqlContext.getConditionContext().find("TABLE_XXX", "field3").get();
assertThat(condition3.getOperator(), is(ShardingOperator.BETWEEN));
assertThat(condition3.getValues().size(), is(2));
assertThat(condition3.getValues().get(0), is((Comparable) 5));
assertThat(condition3.getValues().get(1), is((Comparable) 20));
assertTrue(condition3.getValues().isEmpty());
assertThat(condition3.getValueIndices().size(), is(2));
assertThat(condition3.getValueIndices().get(0), is(6));
assertThat(condition3.getValueIndices().get(1), is(7));
}
@Test(expected = UnsupportedOperationException.class)
......
......@@ -45,6 +45,7 @@ import java.util.Map;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
......@@ -56,7 +57,7 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
List<Object> parameters = Collections.emptyList();
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL, "INSERT INTO `TABLE_XXX` (`field1`, `field2`) VALUES (10, 1)", shardingRule, parameters);
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatement(sqlContext);
assertInsertStatementWithoutParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`, `field2`) VALUES (10, 1)"));
}
......@@ -66,7 +67,7 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
List<Object> parameters = Lists.<Object>newArrayList(10, 1);
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL, "INSERT INTO TABLE_XXX (field1, field2) VALUES (?, ?)", shardingRule, parameters);
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatement(sqlContext);
assertInsertStatementWithParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (field1, field2) VALUES (?, ?)"));
}
......@@ -76,7 +77,7 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
List<Object> parameters = Collections.emptyList();
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL, "INSERT INTO `TABLE_XXX` (`field1`) VALUES (10)", shardingRule, parameters);
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatement(sqlContext);
assertInsertStatementWithoutParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`, field2) VALUES (10, 1)"));
}
......@@ -86,11 +87,11 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
List<Object> parameters = Lists.<Object>newArrayList(10);
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL, "INSERT INTO `TABLE_XXX` (`field1`) VALUES (?)", shardingRule, parameters);
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatement(sqlContext);
assertInsertStatementWithParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`, field2) VALUES (?, ?)"));
}
private void assertInsertStatement(final InsertSQLContext sqlContext) {
private void assertInsertStatementWithoutParameter(final InsertSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
......@@ -104,6 +105,22 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
assertThat(condition2.getValues().get(0), is((Comparable) 1));
}
private void assertInsertStatementWithParameter(final InsertSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
assertTrue(condition1.getValues().isEmpty());
assertThat(condition1.getValueIndices().size(), is(1));
assertThat(condition1.getValueIndices().get(0), is(0));
ConditionContext.Condition condition2 = sqlContext.getConditionContext().find("TABLE_XXX", "field2").get();
assertThat(condition2.getShardingColumnContext().getColumnName(), is("field2"));
assertThat(condition2.getShardingColumnContext().getTableName(), is("TABLE_XXX"));
assertThat(condition2.getOperator(), is(ShardingOperator.EQUAL));
assertTrue(condition2.getValues().isEmpty());
assertThat(condition2.getValueIndices().size(), is(1));
assertThat(condition2.getValueIndices().get(0), is(1));
}
private ShardingRule createShardingRuleWithAutoIncrementColumns() {
DataSource dataSource = mock(DataSource.class);
Connection connection = mock(Connection.class);
......
......@@ -34,6 +34,7 @@ import java.util.List;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public final class UpdateStatementParserTest extends AbstractStatementParserTest {
......@@ -54,12 +55,31 @@ public final class UpdateStatementParserTest extends AbstractStatementParserTest
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL, "UPDATE TABLE_XXX xxx SET TABLE_XXX.field1=field1+1,xxx.field2=2 WHERE TABLE_XXX.field4<10 AND"
+ " TABLE_XXX.field1=1 AND xxx.field5>10 AND TABLE_XXX.field2 IN (1,3) AND xxx.field6<=10 AND TABLE_XXX.field3 BETWEEN 5 AND 20 AND xxx.field7>=10", shardingRule, parameters);
UpdateSQLContext sqlContext = (UpdateSQLContext) statementParser.parseStatement();
assertUpdateStatement(sqlContext);
assertUpdateStatementWithoutParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(),
is("UPDATE [Token(TABLE_XXX)] xxx SET [Token(TABLE_XXX)].field1=field1+1,xxx.field2=2 WHERE [Token(TABLE_XXX)].field4<10 "
+ "AND [Token(TABLE_XXX)].field1=1 AND xxx.field5>10 AND [Token(TABLE_XXX)].field2 IN (1,3) AND xxx.field6<=10 AND [Token(TABLE_XXX)].field3 BETWEEN 5 AND 20 AND xxx.field7>=10"));
}
private void assertUpdateStatementWithoutParameter(final UpdateSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
assertThat(sqlContext.getTables().get(0).getAlias().get(), is("xxx"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
assertThat(condition1.getValues().size(), is(1));
assertThat(condition1.getValues().get(0), is((Comparable) 1));
ConditionContext.Condition condition2 = sqlContext.getConditionContext().find("TABLE_XXX", "field2").get();
assertThat(condition2.getOperator(), is(ShardingOperator.IN));
assertThat(condition2.getValues().size(), is(2));
assertThat(condition2.getValues().get(0), is((Comparable) 1));
assertThat(condition2.getValues().get(1), is((Comparable) 3));
ConditionContext.Condition condition3 = sqlContext.getConditionContext().find("TABLE_XXX", "field3").get();
assertThat(condition3.getOperator(), is(ShardingOperator.BETWEEN));
assertThat(condition3.getValues().size(), is(2));
assertThat(condition3.getValues().get(0), is((Comparable) 5));
assertThat(condition3.getValues().get(1), is((Comparable) 20));
}
@Test
public void parseWithParameter() {
ShardingRule shardingRule = createShardingRule();
......@@ -68,28 +88,31 @@ public final class UpdateStatementParserTest extends AbstractStatementParserTest
"UPDATE TABLE_XXX AS xxx SET field1=field1+? WHERE field4<? AND xxx.field1=? AND field5>? AND xxx.field2 IN (?, ?) AND field6<=? AND xxx.field3 BETWEEN ? AND ? AND field7>=?",
shardingRule, parameters);
UpdateSQLContext sqlContext = (UpdateSQLContext) statementParser.parseStatement();
assertUpdateStatement(sqlContext);
assertUpdateStatementWitParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is("UPDATE [Token(TABLE_XXX)] AS xxx SET field1=field1+? "
+ "WHERE field4<? AND xxx.field1=? AND field5>? AND xxx.field2 IN (?, ?) AND field6<=? AND xxx.field3 BETWEEN ? AND ? AND field7>=?"));
}
private void assertUpdateStatement(final UpdateSQLContext sqlContext) {
private void assertUpdateStatementWitParameter(final UpdateSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
assertThat(sqlContext.getTables().get(0).getAlias().get(), is("xxx"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
assertThat(condition1.getValues().size(), is(1));
assertThat(condition1.getValues().get(0), is((Comparable) 1));
assertTrue(condition1.getValues().isEmpty());
assertThat(condition1.getValueIndices().size(), is(1));
assertThat(condition1.getValueIndices().get(0), is(2));
ConditionContext.Condition condition2 = sqlContext.getConditionContext().find("TABLE_XXX", "field2").get();
assertThat(condition2.getOperator(), is(ShardingOperator.IN));
assertThat(condition2.getValues().size(), is(2));
assertThat(condition2.getValues().get(0), is((Comparable) 1));
assertThat(condition2.getValues().get(1), is((Comparable) 3));
assertTrue(condition2.getValues().isEmpty());
assertThat(condition2.getValueIndices().size(), is(2));
assertThat(condition2.getValueIndices().get(0), is(4));
assertThat(condition2.getValueIndices().get(1), is(5));
ConditionContext.Condition condition3 = sqlContext.getConditionContext().find("TABLE_XXX", "field3").get();
assertThat(condition3.getOperator(), is(ShardingOperator.BETWEEN));
assertThat(condition3.getValues().size(), is(2));
assertThat(condition3.getValues().get(0), is((Comparable) 5));
assertThat(condition3.getValues().get(1), is((Comparable) 20));
assertTrue(condition3.getValues().isEmpty());
assertThat(condition3.getValueIndices().size(), is(2));
assertThat(condition3.getValueIndices().get(0), is(7));
assertThat(condition3.getValueIndices().get(1), is(8));
}
@Test(expected = SQLParsingUnsupportedException.class)
......
......@@ -23,6 +23,7 @@ import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import org.junit.Test;
import java.util.Collections;
import java.util.Iterator;
import static org.hamcrest.core.Is.is;
......@@ -34,13 +35,13 @@ public class SingleRouterUtilTest {
public void testConvertConditionToShardingValue() throws Exception {
ConditionContext.Condition condition = new ConditionContext.Condition(new ShardingColumnContext("test", "test"), ShardingOperator.EQUAL);
condition.getValues().add(1);
ShardingValue<?> shardingValue = SingleRouterUtil.convertConditionToShardingValue(condition);
ShardingValue<?> shardingValue = SingleRouterUtil.convertConditionToShardingValue(condition, Collections.emptyList());
assertThat(shardingValue.getType(), is(ShardingValue.ShardingValueType.SINGLE));
assertThat((Integer) shardingValue.getValue(), is(1));
condition = new ConditionContext.Condition(new ShardingColumnContext("test", "test"), ShardingOperator.IN);
condition.getValues().add(1);
condition.getValues().add(2);
shardingValue = SingleRouterUtil.convertConditionToShardingValue(condition);
shardingValue = SingleRouterUtil.convertConditionToShardingValue(condition, Collections.emptyList());
assertThat(shardingValue.getType(), is(ShardingValue.ShardingValueType.LIST));
Iterator<?> iterator = shardingValue.getValues().iterator();
assertThat((Integer) iterator.next(), is(1));
......@@ -48,7 +49,7 @@ public class SingleRouterUtilTest {
condition = new ConditionContext.Condition(new ShardingColumnContext("test", "test"), ShardingOperator.BETWEEN);
condition.getValues().add(1);
condition.getValues().add(2);
shardingValue = SingleRouterUtil.convertConditionToShardingValue(condition);
shardingValue = SingleRouterUtil.convertConditionToShardingValue(condition, Collections.emptyList());
assertThat(shardingValue.getType(), is(ShardingValue.ShardingValueType.RANGE));
assertThat((Integer) shardingValue.getValueRange().lowerEndpoint(), is(1));
assertThat((Integer) shardingValue.getValueRange().upperEndpoint(), is(2));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册