未验证 提交 adf9c473 编写于 作者: N Neil Dong 提交者: GitHub

Issue5423 and issue5465 (#5590)

* BugFix:1.ON DUPLICATE KEY UPDATE Sub-Clause paramter dropped   https://github.com/apache/shardingsphere/issues/5210

BugFix: 2.ON DUPLICATE KEY UPDATE Sub-Clause encrypt logic: missing assistedQueryColumn.

* Add Apache License to org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext

* 1.Add Tests for org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext

2.Improve Tests for GroupedParameterBuilder

* 1.Improve Tests coverage of InsertStatementContextTest

* 1.Remove EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter implement of QueryWithCipherColumnAware

* 1.Improve Tests coverage of GroupedParameterBuilder ,OnDuplicateUpdateContext

* 1.Improve Tests coverage of GroupedParameterBuilder

* Add testcase ,   sharding test case will fail.

* Add testcase ,   sharding test case will fail.

* BugFix:
1.https://github.com/apache/shardingsphere/issues/5210
ON DUPLICATE KEY UPDATE Sub-Clause paramter dropped

BugFix:
2.https://github.com/apache/shardingsphere/issues/5465
1)ON DUPLICATE KEY UPDATE Sub-Clause encrypt logic: missing assistedQueryColumn.
2)REMOVE wrong implements of QueryWithCipherColumnAware in  org.apache.shardingsphere.encrypt.rewrite.parameter.impl.EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter

* merge upstream

* improve code style.

* improve code style.

* improve code style.

* improve code style.

* improve code style.

* improve code style.

* Remove redundant blank lines

* Update OnDuplicateUpdateContext.java

document format adjusted.

* Update OnDuplicateUpdateContextTest.java

ajust code style
上级 55438faf
......@@ -19,19 +19,17 @@ package org.apache.shardingsphere.encrypt.rewrite.parameter.impl;
import com.google.common.base.Preconditions;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.QueryWithCipherColumnAware;
import org.apache.shardingsphere.encrypt.rewrite.parameter.EncryptParameterRewriter;
import org.apache.shardingsphere.encrypt.strategy.spi.Encryptor;
import org.apache.shardingsphere.encrypt.strategy.spi.QueryAssistedEncryptor;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext;
import org.apache.shardingsphere.sql.parser.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.OnDuplicateKeyColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.underlying.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
......@@ -39,9 +37,7 @@ import java.util.Optional;
* Insert on duplicate key update parameter rewriter for encrypt.
*/
@Setter
public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter extends EncryptParameterRewriter<InsertStatementContext> implements QueryWithCipherColumnAware {
private boolean queryWithCipherColumn;
public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter extends EncryptParameterRewriter<InsertStatementContext> {
@Override
protected boolean isNeedRewriteForEncrypt(final SQLStatementContext sqlStatementContext) {
......@@ -51,30 +47,34 @@ public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter exten
@Override
public void rewrite(final ParameterBuilder parameterBuilder, final InsertStatementContext insertStatementContext, final List<Object> parameters) {
String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
Preconditions.checkState(insertStatementContext.getSqlStatement().getOnDuplicateKeyColumns().isPresent());
OnDuplicateKeyColumnsSegment onDuplicateKeyColumnsSegment = insertStatementContext.getSqlStatement().getOnDuplicateKeyColumns().get();
Collection<AssignmentSegment> onDuplicateKeyColumnsSegments = onDuplicateKeyColumnsSegment.getColumns();
if (onDuplicateKeyColumnsSegments.isEmpty()) {
return;
}
GroupedParameterBuilder groupedParameterBuilder = (GroupedParameterBuilder) parameterBuilder;
for (AssignmentSegment each : onDuplicateKeyColumnsSegments) {
ExpressionSegment expressionSegment = each.getValue();
Object cipherColumnValue;
Object plainColumnValue = null;
if (expressionSegment instanceof ParameterMarkerExpressionSegment) {
plainColumnValue = parameters.get(((ParameterMarkerExpressionSegment) expressionSegment).getParameterMarkerIndex());
}
if (queryWithCipherColumn) {
Optional<Encryptor> encryptor = getEncryptRule().findEncryptor(tableName, each.getColumn().getIdentifier().getValue());
if (encryptor.isPresent()) {
cipherColumnValue = encryptor.get().encrypt(plainColumnValue);
groupedParameterBuilder.getOnDuplicateKeyUpdateAddedParameters().add(cipherColumnValue);
OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext = insertStatementContext.getOnDuplicateKeyUpdateValueContext();
for (int index = 0; index < onDuplicateKeyUpdateValueContext.getValueExpressions().size(); index++) {
final int columnIndex = index;
String encryptLogicColumnName = onDuplicateKeyUpdateValueContext.getColumn(columnIndex).getIdentifier().getValue();
Optional<Encryptor> encryptorOptional = getEncryptRule().findEncryptor(tableName, encryptLogicColumnName);
encryptorOptional.ifPresent(encryptor -> {
Object plainColumnValue = onDuplicateKeyUpdateValueContext.getValue(columnIndex);
Object cipherColumnValue = encryptorOptional.get().encrypt(plainColumnValue);
groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().addReplacedParameters(columnIndex, cipherColumnValue);
Collection<Object> addedParameters = new LinkedList<>();
if (encryptor instanceof QueryAssistedEncryptor) {
Optional<String> assistedColumnName = getEncryptRule().findAssistedQueryColumn(tableName, encryptLogicColumnName);
Preconditions.checkArgument(assistedColumnName.isPresent(), "Can not find assisted query Column Name");
addedParameters.add(((QueryAssistedEncryptor) encryptor).queryAssistedEncrypt(plainColumnValue.toString()));
}
if (getEncryptRule().findPlainColumn(tableName, encryptLogicColumnName).isPresent()) {
addedParameters.add(plainColumnValue);
}
if (!addedParameters.isEmpty()) {
if (!groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().getAddedIndexAndParameters().containsKey(columnIndex + 1)) {
groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().getAddedIndexAndParameters().put(columnIndex + 1, new LinkedList<>());
}
groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().getAddedIndexAndParameters().get(columnIndex + 1).addAll(addedParameters);
}
}
if (null != plainColumnValue) {
groupedParameterBuilder.getOnDuplicateKeyUpdateAddedParameters().add(plainColumnValue);
}
});
}
}
}
......@@ -77,6 +77,7 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex());
String columnName = assignmentSegment.getColumn().getIdentifier().getValue();
addCipherColumn(tableName, columnName, result);
addAssistedQueryColumn(tableName, columnName, result);
addPlainColumn(tableName, columnName, result);
return result;
}
......@@ -84,6 +85,7 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok
private EncryptAssignmentToken generateLiteralSQLToken(final String tableName, final AssignmentSegment assignmentSegment) {
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex());
addCipherAssignment(tableName, assignmentSegment, result);
addAssistedQueryAssignment(tableName, assignmentSegment, result);
addPlainAssignment(tableName, assignmentSegment, result);
return result;
}
......@@ -91,6 +93,10 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok
private void addCipherColumn(final String tableName, final String columnName, final EncryptParameterAssignmentToken token) {
token.addColumnName(getEncryptRule().getCipherColumn(tableName, columnName));
}
private void addAssistedQueryColumn(final String tableName, final String columnName, final EncryptParameterAssignmentToken token) {
getEncryptRule().findAssistedQueryColumn(tableName, columnName).ifPresent(token::addColumnName);
}
private void addPlainColumn(final String tableName, final String columnName, final EncryptParameterAssignmentToken token) {
getEncryptRule().findPlainColumn(tableName, columnName).ifPresent(token::addColumnName);
......@@ -102,6 +108,15 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok
token.addAssignment(getEncryptRule().getCipherColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()), cipherValue);
}
private void addAssistedQueryAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
getEncryptRule().findAssistedQueryColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()).ifPresent(assistedQueryColumn -> {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
Object assistedQueryValue = getEncryptRule().getEncryptAssistedQueryValues(tableName, assignmentSegment.getColumn().getIdentifier().getValue(), Collections.singletonList(originalValue))
.iterator().next();
token.addAssignment(assistedQueryColumn, assistedQueryValue);
});
}
private void addPlainAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) {
Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals();
getEncryptRule().findPlainColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()).ifPresent(plainColumn -> token.addAssignment(plainColumn, originalValue));
......
......@@ -44,12 +44,17 @@
<rewrite-assertion id="insert_values_with_on_duplicate_key_update_with_columns_with_plain_for_parameters" db-type="MySQL">
<input sql="INSERT INTO t_account_bak(account_id, certificate_number, password, amount, status) VALUES (?, ?, ?, ?, ?), (2, '222X', 'bbb', 2000, 'OK'), (?, ?, ?, ?, ?), (4, '444X', 'ddd', 4000, 'OK') ON DUPLICATE KEY UPDATE password = ?" parameters="1, 111X, aaa, 1000, OK, 3, 333X, ccc, 3000, OK, ccc_update" />
<output sql="INSERT INTO t_account_bak(account_id, cipher_certificate_number, assisted_query_certificate_number, plain_certificate_number, cipher_password, assisted_query_password, plain_password, cipher_amount, plain_amount, status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?), (2, 'encrypt_222X', 'assisted_query_222X', '222X', 'encrypt_bbb', 'assisted_query_bbb', 'bbb', 'encrypt_2000', 2000, 'OK'), (?, ?, ?, ?, ?, ?, ?, ?, ?, ?), (4, 'encrypt_444X', 'assisted_query_444X', '444X', 'encrypt_ddd', 'assisted_query_ddd', 'ddd', 'encrypt_4000', 4000, 'OK') ON DUPLICATE KEY UPDATE cipher_password = ?, plain_password = ?" parameters="1, encrypt_111X, assisted_query_111X, 111X, encrypt_aaa, assisted_query_aaa, aaa, encrypt_1000, 1000, OK, 3, encrypt_333X, assisted_query_333X, 333X, encrypt_ccc, assisted_query_ccc, ccc, encrypt_3000, 3000, OK, encrypt_ccc_update, ccc_update" />
<output sql="INSERT INTO t_account_bak(account_id, cipher_certificate_number, assisted_query_certificate_number, plain_certificate_number, cipher_password, assisted_query_password, plain_password, cipher_amount, plain_amount, status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?), (2, 'encrypt_222X', 'assisted_query_222X', '222X', 'encrypt_bbb', 'assisted_query_bbb', 'bbb', 'encrypt_2000', 2000, 'OK'), (?, ?, ?, ?, ?, ?, ?, ?, ?, ?), (4, 'encrypt_444X', 'assisted_query_444X', '444X', 'encrypt_ddd', 'assisted_query_ddd', 'ddd', 'encrypt_4000', 4000, 'OK') ON DUPLICATE KEY UPDATE cipher_password = ?, assisted_query_password = ?, plain_password = ?" parameters="1, encrypt_111X, assisted_query_111X, 111X, encrypt_aaa, assisted_query_aaa, aaa, encrypt_1000, 1000, OK, 3, encrypt_333X, assisted_query_333X, 333X, encrypt_ccc, assisted_query_ccc, ccc, encrypt_3000, 3000, OK, encrypt_ccc_update, assisted_query_ccc_update, ccc_update" />
</rewrite-assertion>
<rewrite-assertion id="insert_values_with_on_duplicate_key_update_with_columns_with_plain_for_literals" db-type="MySQL">
<input sql="INSERT INTO t_account_bak(account_id, certificate_number, password, amount, status) VALUES (1, '111X', 'aaa', 1000, 'OK'), (2, '222X', 'bbb', 2000, 'OK'), (3, '333X', 'ccc', 3000, 'OK'), (4, '444X', 'ddd', 4000, 'OK') ON DUPLICATE KEY UPDATE password = 'ccc_update'" />
<output sql="INSERT INTO t_account_bak(account_id, cipher_certificate_number, assisted_query_certificate_number, plain_certificate_number, cipher_password, assisted_query_password, plain_password, cipher_amount, plain_amount, status) VALUES (1, 'encrypt_111X', 'assisted_query_111X', '111X', 'encrypt_aaa', 'assisted_query_aaa', 'aaa', 'encrypt_1000', 1000, 'OK'), (2, 'encrypt_222X', 'assisted_query_222X', '222X', 'encrypt_bbb', 'assisted_query_bbb', 'bbb', 'encrypt_2000', 2000, 'OK'), (3, 'encrypt_333X', 'assisted_query_333X', '333X', 'encrypt_ccc', 'assisted_query_ccc', 'ccc', 'encrypt_3000', 3000, 'OK'), (4, 'encrypt_444X', 'assisted_query_444X', '444X', 'encrypt_ddd', 'assisted_query_ddd', 'ddd', 'encrypt_4000', 4000, 'OK') ON DUPLICATE KEY UPDATE cipher_password = 'encrypt_ccc_update', plain_password = 'ccc_update'" />
<output sql="INSERT INTO t_account_bak(account_id, cipher_certificate_number, assisted_query_certificate_number, plain_certificate_number, cipher_password, assisted_query_password, plain_password, cipher_amount, plain_amount, status) VALUES (1, 'encrypt_111X', 'assisted_query_111X', '111X', 'encrypt_aaa', 'assisted_query_aaa', 'aaa', 'encrypt_1000', 1000, 'OK'), (2, 'encrypt_222X', 'assisted_query_222X', '222X', 'encrypt_bbb', 'assisted_query_bbb', 'bbb', 'encrypt_2000', 2000, 'OK'), (3, 'encrypt_333X', 'assisted_query_333X', '333X', 'encrypt_ccc', 'assisted_query_ccc', 'ccc', 'encrypt_3000', 3000, 'OK'), (4, 'encrypt_444X', 'assisted_query_444X', '444X', 'encrypt_ddd', 'assisted_query_ddd', 'ddd', 'encrypt_4000', 4000, 'OK') ON DUPLICATE KEY UPDATE cipher_password = 'encrypt_ccc_update', assisted_query_password = 'assisted_query_ccc_update', plain_password = 'ccc_update'" />
</rewrite-assertion>
<rewrite-assertion id="insert_values_with_on_duplicate_key_update_with_insert_value_literals_and_on_duplicate_parameterized" db-type="MySQL">
<input sql="INSERT INTO t_account_bak(account_id, certificate_number, password, amount, status) VALUES (1, '111X', 'aaa', 1000, 'OK'), (2, '222X', 'bbb', 2000, 'OK'), (3, '333X', 'ccc', 3000, 'OK'), (4, '444X', 'ddd', 4000, 'OK') ON DUPLICATE KEY UPDATE password = ?" parameters="ccc_update" />
<output sql="INSERT INTO t_account_bak(account_id, cipher_certificate_number, assisted_query_certificate_number, plain_certificate_number, cipher_password, assisted_query_password, plain_password, cipher_amount, plain_amount, status) VALUES (1, 'encrypt_111X', 'assisted_query_111X', '111X', 'encrypt_aaa', 'assisted_query_aaa', 'aaa', 'encrypt_1000', 1000, 'OK'), (2, 'encrypt_222X', 'assisted_query_222X', '222X', 'encrypt_bbb', 'assisted_query_bbb', 'bbb', 'encrypt_2000', 2000, 'OK'), (3, 'encrypt_333X', 'assisted_query_333X', '333X', 'encrypt_ccc', 'assisted_query_ccc', 'ccc', 'encrypt_3000', 3000, 'OK'), (4, 'encrypt_444X', 'assisted_query_444X', '444X', 'encrypt_ddd', 'assisted_query_ddd', 'ddd', 'encrypt_4000', 4000, 'OK') ON DUPLICATE KEY UPDATE cipher_password = ?, assisted_query_password = ?, plain_password = ?" parameters="encrypt_ccc_update, assisted_query_ccc_update, ccc_update" />
</rewrite-assertion>
<rewrite-assertion id="insert_values_with_columns_with_plain_for_literals">
......
......@@ -125,4 +125,10 @@
<input sql="INSERT INTO t_account SET amount = ?, status = ?" parameters="1000, OK" />
<output sql="INSERT INTO t_account_1 SET amount = ?, status = ?, account_id = ?" parameters="1000, OK, 1" />
</rewrite-assertion>
<rewrite-assertion id="insert_multiple_values_with_columns_with_id_for_parameters_and_on_duplicate_update" db-type="MySQL">
<input sql="INSERT INTO t_account VALUES (100, 1000, 'OK'), (101, 2000, 'OK'), (102, 1000, 'OK') ON DUPLICATE KEY UPDATE status = ?" parameters="OK_UPDATE" />
<output sql="INSERT INTO t_account_0 VALUES (100, 1000, 'OK'), (102, 1000, 'OK') ON DUPLICATE KEY UPDATE status = ?" parameters="OK_UPDATE" />
<output sql="INSERT INTO t_account_1 VALUES (101, 2000, 'OK') ON DUPLICATE KEY UPDATE status = ?" parameters="OK_UPDATE" />
</rewrite-assertion>
</rewrite-assertions>
......@@ -44,10 +44,14 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ
private static final String INSERT_WITHOUT_GENERATE_KEY_SQL = "INSERT INTO t_order_item (order_id, user_id, status) VALUES (?, ?, ?)";
private static final String INSERT_ON_DUPLICATE_KEY_SQL = "INSERT INTO t_order_item (item_id, order_id, user_id, status) VALUES (?, ?, ?, ?), (?, ?, ?, ?) ON DUPLICATE KEY UPDATE status = ?";
private static final String SELECT_SQL_WITHOUT_PARAMETER_MARKER = "SELECT item_id FROM t_order_item WHERE user_id = %d AND order_id= %s AND status = 'BATCH'";
private static final String SELECT_SQL_WITH_PARAMETER_MARKER = "SELECT item_id FROM t_order_item WHERE user_id = ? AND order_id= ? AND status = 'BATCH'";
private static final String SELECT_SQL_WITH_PARAMETER_MARKER_RETURN_STATUS = "SELECT item_id, user_id, status FROM t_order_item WHERE order_id= ? AND user_id = ?";
private static final String UPDATE_SQL = "UPDATE t_order SET status = ? WHERE user_id = ? AND order_id = ?";
private static final String UPDATE_BATCH_SQL = "UPDATE t_order SET status=? WHERE status=?";
......@@ -83,7 +87,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ
}
}
}
@Ignore
@Test
public void assertMultiValuesWithGenerateShardingKeyColumn() throws SQLException {
......@@ -126,7 +130,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ
}
}
}
@Ignore
@Test
public void assertAddBatchMultiValuesWithGenerateShardingKeyColumn() throws SQLException {
......@@ -316,6 +320,75 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ
}
}
@Test
public void assertAddOnDuplicateKey() throws SQLException {
int itemId = 1;
int userId1 = 101;
int userId2 = 102;
int orderId = 200;
String status = "init";
String updatedStatus = "updated on duplicate key";
try (Connection connection = getShardingSphereDataSource().getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_ON_DUPLICATE_KEY_SQL);
PreparedStatement queryStatement = connection.prepareStatement(SELECT_SQL_WITH_PARAMETER_MARKER_RETURN_STATUS)) {
preparedStatement.setInt(1, itemId);
preparedStatement.setInt(2, orderId);
preparedStatement.setInt(3, userId1);
preparedStatement.setString(4, status);
preparedStatement.setInt(5, itemId);
preparedStatement.setInt(6, orderId);
preparedStatement.setInt(7, userId2);
preparedStatement.setString(8, status);
preparedStatement.setString(9, updatedStatus);
int result = preparedStatement.executeUpdate();
assertThat(result, is(2));
queryStatement.setInt(1, orderId);
queryStatement.setInt(2, userId1);
try (ResultSet resultSet = queryStatement.executeQuery()) {
assertTrue(resultSet.next());
assertThat(resultSet.getInt(2), is(userId1));
assertThat(resultSet.getString(3), is(status));
}
queryStatement.setInt(1, orderId);
queryStatement.setInt(2, userId2);
try (ResultSet resultSet = queryStatement.executeQuery()) {
assertTrue(resultSet.next());
assertThat(resultSet.getInt(2), is(userId2));
assertThat(resultSet.getString(3), is(status));
}
}
try (Connection connection = getShardingSphereDataSource().getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(INSERT_ON_DUPLICATE_KEY_SQL);
PreparedStatement queryStatement = connection.prepareStatement(SELECT_SQL_WITH_PARAMETER_MARKER_RETURN_STATUS)) {
preparedStatement.setInt(1, itemId);
preparedStatement.setInt(2, orderId);
preparedStatement.setInt(3, userId1);
preparedStatement.setString(4, status);
preparedStatement.setInt(5, itemId);
preparedStatement.setInt(6, orderId);
preparedStatement.setInt(7, userId2);
preparedStatement.setString(8, status);
preparedStatement.setString(9, updatedStatus);
int result = preparedStatement.executeUpdate();
assertThat(result, is(2));
queryStatement.setInt(1, orderId);
queryStatement.setInt(2, userId1);
try (ResultSet resultSet = queryStatement.executeQuery()) {
assertTrue(resultSet.next());
assertThat(resultSet.getInt(2), is(userId1));
assertThat(resultSet.getString(3), is(updatedStatus));
}
queryStatement.setInt(1, orderId);
queryStatement.setInt(2, userId2);
try (ResultSet resultSet = queryStatement.executeQuery()) {
assertTrue(resultSet.next());
assertThat(resultSet.getInt(2), is(userId2));
assertThat(resultSet.getString(3), is(updatedStatus));
}
}
}
@Test
public void assertUpdateBatch() throws SQLException {
try (
......@@ -337,7 +410,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ
assertThat(result[2], is(4));
}
}
@Test
public void assertExecuteGetResultSet() throws SQLException {
try (PreparedStatement preparedStatement = getShardingSphereDataSource().getConnection().prepareStatement(UPDATE_SQL)) {
......@@ -348,7 +421,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ
assertNull(preparedStatement.getResultSet());
}
}
@Test
public void assertExecuteUpdateGetResultSet() throws SQLException {
try (PreparedStatement preparedStatement = getShardingSphereDataSource().getConnection().prepareStatement(UPDATE_SQL)) {
......@@ -359,7 +432,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ
assertNull(preparedStatement.getResultSet());
}
}
@Test
public void assertClearBatch() throws SQLException {
try (
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.sql.parser.binder.segment.insert.values;
import lombok.Getter;
import lombok.ToString;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@Getter
@ToString
public class OnDuplicateUpdateContext {
private final int parametersCount;
private final List<ExpressionSegment> valueExpressions;
private final List<Object> parameters;
private final List<ColumnSegment> columns;
public OnDuplicateUpdateContext(final Collection<AssignmentSegment> assignments, final List<Object> parameters, final int parametersOffset) {
List<ExpressionSegment> expressionSegments = assignments.stream().map(AssignmentSegment::getValue).collect(Collectors.toList());
parametersCount = calculateParametersCount(expressionSegments);
valueExpressions = getValueExpressions(expressionSegments);
this.parameters = getParameters(parameters, parametersOffset);
columns = assignments.stream().map(AssignmentSegment::getColumn).collect(Collectors.toList());
}
private int calculateParametersCount(final Collection<ExpressionSegment> assignments) {
int result = 0;
for (ExpressionSegment each : assignments) {
if (each instanceof ParameterMarkerExpressionSegment) {
result++;
}
}
return result;
}
private List<ExpressionSegment> getValueExpressions(final Collection<ExpressionSegment> assignments) {
List<ExpressionSegment> result = new ArrayList<>(assignments.size());
result.addAll(assignments);
return result;
}
private List<Object> getParameters(final List<Object> parameters, final int parametersOffset) {
if (0 == parametersCount) {
return Collections.emptyList();
}
List<Object> result = new ArrayList<>(parametersCount);
result.addAll(parameters.subList(parametersOffset, parametersOffset + parametersCount));
return result;
}
/**
* Get value.
*
* @param index index
* @return value
*/
public Object getValue(final int index) {
ExpressionSegment valueExpression = valueExpressions.get(index);
return valueExpression instanceof ParameterMarkerExpressionSegment ? parameters.get(getParameterIndex(valueExpression)) : ((LiteralExpressionSegment) valueExpression).getLiterals();
}
private int getParameterIndex(final ExpressionSegment valueExpression) {
int result = 0;
for (ExpressionSegment each : valueExpressions) {
if (valueExpression == each) {
return result;
}
if (each instanceof ParameterMarkerExpressionSegment) {
result++;
}
}
throw new IllegalArgumentException("Can not get parameter index.");
}
/**
* Get on duplicate key update column by index of this clause.
*
* @param index index
* @return columnSegment
*/
public ColumnSegment getColumn(final int index) {
return columns.get(index);
}
}
......@@ -23,19 +23,23 @@ import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaDat
import org.apache.shardingsphere.sql.parser.binder.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.keygen.engine.GeneratedKeyContextEngine;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext;
import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Insert SQL statement context.
......@@ -50,30 +54,43 @@ public final class InsertStatementContext extends CommonSQLStatementContext<Inse
private final List<InsertValueContext> insertValueContexts;
private final OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext;
private final GeneratedKeyContext generatedKeyContext;
public InsertStatementContext(final SchemaMetaData schemaMetaData, final List<Object> parameters, final InsertStatement sqlStatement) {
super(sqlStatement);
tablesContext = new TablesContext(sqlStatement.getTable());
columnNames = sqlStatement.useDefaultColumns() ? schemaMetaData.getAllColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue()) : sqlStatement.getColumnNames();
insertValueContexts = getInsertValueContexts(parameters);
AtomicInteger parametersOffset = new AtomicInteger(0);
insertValueContexts = getInsertValueContexts(parameters, parametersOffset);
onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(parameters, parametersOffset).orElse(null);
generatedKeyContext = new GeneratedKeyContextEngine(schemaMetaData).createGenerateKeyContext(parameters, sqlStatement).orElse(null);
}
private List<InsertValueContext> getInsertValueContexts(final List<Object> parameters) {
private List<InsertValueContext> getInsertValueContexts(final List<Object> parameters, final AtomicInteger parametersOffset) {
List<InsertValueContext> result = new LinkedList<>();
int parametersOffset = 0;
for (Collection<ExpressionSegment> each : getSqlStatement().getAllValueExpressions()) {
InsertValueContext insertValueContext = new InsertValueContext(each, parameters, parametersOffset);
InsertValueContext insertValueContext = new InsertValueContext(each, parameters, parametersOffset.get());
result.add(insertValueContext);
parametersOffset += insertValueContext.getParametersCount();
parametersOffset.addAndGet(insertValueContext.getParametersCount());
}
return result;
}
private Optional<OnDuplicateUpdateContext> getOnDuplicateKeyUpdateValueContext(final List<Object> parameters, final AtomicInteger parametersOffset) {
if (!getSqlStatement().getOnDuplicateKeyColumns().isPresent()) {
return Optional.empty();
}
Collection<AssignmentSegment> onDuplicateKeyColumns = getSqlStatement().getOnDuplicateKeyColumns().get().getColumns();
OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(onDuplicateKeyColumns, parameters, parametersOffset.get());
parametersOffset.addAndGet(onDuplicateUpdateContext.getParametersCount());
return Optional.of(onDuplicateUpdateContext);
}
/**
* Get column names for descending order.
*
*
* @return column names for descending order
*/
public Iterator<String> getDescendingColumnNames() {
......@@ -82,7 +99,7 @@ public final class InsertStatementContext extends CommonSQLStatementContext<Inse
/**
* Get grouped parameters.
*
*
* @return grouped parameters
*/
public List<List<Object>> getGroupedParameters() {
......@@ -93,9 +110,21 @@ public final class InsertStatementContext extends CommonSQLStatementContext<Inse
return result;
}
/**
* Get on duplicate key update parameters.
*
* @return on duplicate key update parameters
*/
public List<Object> getOnDuplicateKeyUpdateParameters() {
if (null == onDuplicateKeyUpdateValueContext) {
return new ArrayList<>(0);
}
return onDuplicateKeyUpdateValueContext.getParameters();
}
/**
* Get generated key context.
*
*
* @return generated key context
*/
public Optional<GeneratedKeyContext> getGeneratedKeyContext() {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.sql.parser.binder.segment.insert.values;
import com.google.common.collect.Lists;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.SimpleExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.value.identifier.IdentifierValue;
import org.junit.Assert;
import org.junit.Test;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public final class OnDuplicateUpdateContextTest {
@SuppressWarnings("unchecked")
@Test
public void assertInstanceConstructedOk() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
Collection<AssignmentSegment> assignments = Lists.newArrayList();
List<Object> parameters = Collections.emptyList();
int parametersOffset = 0;
OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, parametersOffset);
Method calculateParametersCountMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("calculateParametersCount", Collection.class);
calculateParametersCountMethod.setAccessible(true);
int calculateParametersCountResult = (int) calculateParametersCountMethod.invoke(onDuplicateUpdateContext, new Object[]{assignments});
assertThat(onDuplicateUpdateContext.getParametersCount(), is(calculateParametersCountResult));
Method getValueExpressionsMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("getValueExpressions", Collection.class);
getValueExpressionsMethod.setAccessible(true);
List<ExpressionSegment> getValueExpressionsResult = (List<ExpressionSegment>) getValueExpressionsMethod.invoke(onDuplicateUpdateContext, new Object[]{assignments});
assertThat(onDuplicateUpdateContext.getValueExpressions(), is(getValueExpressionsResult));
Method getParametersMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("getParameters", List.class, int.class);
getParametersMethod.setAccessible(true);
List<Object> getParametersResult = (List<Object>) getParametersMethod.invoke(onDuplicateUpdateContext, new Object[]{parameters, parametersOffset});
assertThat(onDuplicateUpdateContext.getParameters(), is(getParametersResult));
}
@Test
public void assertGetValueWhenParameterMarker() {
Collection<AssignmentSegment> assignments = makeParameterMarkerExpressionAssignmentSegment();
String parameterValue1 = "test1";
String parameterValue2 = "test2";
List<Object> parameters = Lists.newArrayList(parameterValue1, parameterValue2);
int parametersOffset = 0;
OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, parametersOffset);
Object valueFromInsertValueContext1 = onDuplicateUpdateContext.getValue(0);
assertThat(valueFromInsertValueContext1, is(parameterValue1));
Object valueFromInsertValueContext2 = onDuplicateUpdateContext.getValue(1);
assertThat(valueFromInsertValueContext2, is(parameterValue2));
}
private Collection<AssignmentSegment> makeParameterMarkerExpressionAssignmentSegment() {
ParameterMarkerExpressionSegment parameterMarkerExpressionSegment = new ParameterMarkerExpressionSegment(0, 10, 5);
AssignmentSegment assignmentSegment1 = makeAssignmentSegment(parameterMarkerExpressionSegment);
ParameterMarkerExpressionSegment parameterMarkerExpressionSegment2 = new ParameterMarkerExpressionSegment(0, 10, 6);
AssignmentSegment assignmentSegment2 = makeAssignmentSegment(parameterMarkerExpressionSegment2);
return Lists.newArrayList(assignmentSegment1, assignmentSegment2);
}
@Test
public void assertGetValueWhenLiteralExpressionSegment() {
Object literalObject = new Object();
Collection<AssignmentSegment> assignments = makeLiteralExpressionSegment(literalObject);
List<Object> parameters = Collections.emptyList();
OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, 0);
Object valueFromInsertValueContext = onDuplicateUpdateContext.getValue(0);
assertThat(valueFromInsertValueContext, is(literalObject));
}
private Collection<AssignmentSegment> makeLiteralExpressionSegment(final Object literalObject) {
LiteralExpressionSegment parameterLiteralExpression = new LiteralExpressionSegment(0, 10, literalObject);
AssignmentSegment assignmentSegment = makeAssignmentSegment(parameterLiteralExpression);
return Collections.singleton(assignmentSegment);
}
private AssignmentSegment makeAssignmentSegment(final SimpleExpressionSegment expressionSegment) {
int doesNotMatterLexicalIndex = 0;
String doesNotMatterColumnName = "columnNameStr";
ColumnSegment column = new ColumnSegment(doesNotMatterLexicalIndex, doesNotMatterLexicalIndex, new IdentifierValue(doesNotMatterColumnName));
return new AssignmentSegment(doesNotMatterLexicalIndex, doesNotMatterLexicalIndex, column, expressionSegment);
}
@Test
public void assertGetParameterIndex() throws NoSuchMethodException, IllegalAccessException {
Collection<AssignmentSegment> assignments = Lists.newArrayList();
List<Object> parameters = Collections.emptyList();
int parametersOffset = 0;
OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, parametersOffset);
Method getParameterIndexMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("getParameterIndex", ExpressionSegment.class);
getParameterIndexMethod.setAccessible(true);
ParameterMarkerExpressionSegment notExistsExpressionSegment = new ParameterMarkerExpressionSegment(0, 0, 0);
Throwable targetException = null;
try {
getParameterIndexMethod.invoke(onDuplicateUpdateContext, notExistsExpressionSegment);
} catch (InvocationTargetException e) {
targetException = e.getTargetException();
}
assertTrue("expected throw IllegalArgumentException", targetException instanceof IllegalArgumentException);
}
@Test
public void assertGetColumn() {
Object literalObject = new Object();
Collection<AssignmentSegment> assignments = makeLiteralExpressionSegment(literalObject);
List<Object> parameters = Collections.emptyList();
OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, 0);
ColumnSegment column = onDuplicateUpdateContext.getColumn(0);
Assert.assertThat(column, is(assignments.iterator().next().getColumn()));
}
}
......@@ -19,9 +19,11 @@ package org.apache.shardingsphere.sql.parser.binder.statement.impl;
import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaData;
import org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.OnDuplicateKeyColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment;
......@@ -32,6 +34,7 @@ import org.junit.Test;
import java.util.Arrays;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
......@@ -62,7 +65,7 @@ public final class InsertStatementContextTest {
}
@Test
public void assertGetGroupedParameters() {
public void assertGetGroupedParametersWithoutOnDuplicateParameter() {
SchemaMetaData schemaMetaData = mock(SchemaMetaData.class);
when(schemaMetaData.getAllColumnNames("tbl")).thenReturn(Arrays.asList("id", "name", "status"));
InsertStatement insertStatement = new InsertStatement();
......@@ -70,6 +73,22 @@ public final class InsertStatementContextTest {
setUpInsertValues(insertStatement);
InsertStatementContext actual = new InsertStatementContext(schemaMetaData, Arrays.asList(1, "Tom", 2, "Jerry"), insertStatement);
assertThat(actual.getGroupedParameters().size(), is(2));
assertNull(actual.getOnDuplicateKeyUpdateValueContext());
assertThat(actual.getOnDuplicateKeyUpdateParameters().size(), is(0));
}
@Test
public void assertGetGroupedParametersWithOnDuplicateParameters() {
SchemaMetaData schemaMetaData = mock(SchemaMetaData.class);
when(schemaMetaData.getAllColumnNames("tbl")).thenReturn(Arrays.asList("id", "name", "status"));
InsertStatement insertStatement = new InsertStatement();
insertStatement.setTable(new SimpleTableSegment(0, 0, new IdentifierValue("tbl")));
setUpInsertValues(insertStatement);
setUpOnDuplicateValues(insertStatement);
InsertStatementContext actual = new InsertStatementContext(schemaMetaData, Arrays.asList(1, "Tom", 2, "Jerry", "onDuplicateKeyUpdateColumnValue"), insertStatement);
assertThat(actual.getGroupedParameters().size(), is(2));
assertThat(actual.getOnDuplicateKeyUpdateValueContext().getColumns().size(), is(2));
assertThat(actual.getOnDuplicateKeyUpdateParameters().size(), is(1));
}
private void setUpInsertValues(final InsertStatement insertStatement) {
......@@ -79,6 +98,21 @@ public final class InsertStatementContextTest {
new ParameterMarkerExpressionSegment(0, 0, 3), new ParameterMarkerExpressionSegment(0, 0, 4), new LiteralExpressionSegment(0, 0, "init"))));
}
private void setUpOnDuplicateValues(final InsertStatement insertStatement) {
AssignmentSegment parameterMarkerExpressionAssignment = new AssignmentSegment(0, 0,
new ColumnSegment(0, 0, new IdentifierValue("on_duplicate_key_update_column_1")),
new ParameterMarkerExpressionSegment(0, 0, 4)
);
AssignmentSegment literalExpressionAssignment = new AssignmentSegment(0, 0,
new ColumnSegment(0, 0, new IdentifierValue("on_duplicate_key_update_column_2")),
new LiteralExpressionSegment(0, 0, 5)
);
OnDuplicateKeyColumnsSegment onDuplicateKeyColumnsSegment = new OnDuplicateKeyColumnsSegment(0, 0, Arrays.asList(
parameterMarkerExpressionAssignment, literalExpressionAssignment
));
insertStatement.setOnDuplicateKeyColumns(onDuplicateKeyColumnsSegment);
}
private void assertInsertStatementContext(final InsertStatementContext actual) {
assertThat(actual.getColumnNames(), is(Arrays.asList("id", "name", "status")));
assertThat(actual.getInsertValueContexts().size(), is(2));
......
......@@ -198,6 +198,7 @@ public final class MySQLDMLVisitor extends MySQLVisitor implements DMLVisitor {
Collection<AssignmentSegment> columns = new LinkedList<>();
for (AssignmentContext each : ctx.assignment()) {
columns.add((AssignmentSegment) visit(each));
visit(each.assignmentValue());
}
return new OnDuplicateKeyColumnsSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), columns);
}
......
......@@ -62,12 +62,13 @@ public final class SQLRewriteContext {
this.parameters = parameters;
addSQLTokenGenerators(new DefaultTokenGeneratorBuilder().getSQLTokenGenerators());
parameterBuilder = sqlStatementContext instanceof InsertStatementContext
? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters()) : new StandardParameterBuilder(parameters);
? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters())
: new StandardParameterBuilder(parameters);
}
/**
* Add SQL token generators.
*
*
* @param sqlTokenGenerators SQL token generators
*/
public void addSQLTokenGenerators(final Collection<SQLTokenGenerator> sqlTokenGenerators) {
......
......@@ -20,8 +20,14 @@ package org.apache.shardingsphere.underlying.rewrite.engine;
import org.apache.shardingsphere.underlying.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.underlying.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.underlying.rewrite.engine.result.SQLRewriteUnit;
import org.apache.shardingsphere.underlying.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.StandardParameterBuilder;
import org.apache.shardingsphere.underlying.rewrite.sql.impl.DefaultSQLBuilder;
import java.util.LinkedList;
import java.util.List;
/**
* Generic SQL rewrite engine.
*/
......@@ -34,6 +40,22 @@ public final class GenericSQLRewriteEngine {
* @return SQL rewrite result
*/
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) {
return new GenericSQLRewriteResult(new SQLRewriteUnit(new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getParameterBuilder().getParameters()));
return new GenericSQLRewriteResult(new SQLRewriteUnit(new DefaultSQLBuilder(sqlRewriteContext).toSQL(), getParameters(sqlRewriteContext.getParameterBuilder())));
}
private List<Object> getParameters(final ParameterBuilder parameterBuilder) {
if (parameterBuilder instanceof StandardParameterBuilder) {
return parameterBuilder.getParameters();
}
List<Object> onDuplicateKeyUpdateParameters = ((GroupedParameterBuilder) parameterBuilder).getOnDuplicateKeyUpdateParametersBuilder().getParameters();
if (onDuplicateKeyUpdateParameters.isEmpty()) {
return parameterBuilder.getParameters();
}
List<Object> result = new LinkedList<>();
result.addAll(parameterBuilder.getParameters());
result.addAll(onDuplicateKeyUpdateParameters);
return result;
}
}
......@@ -55,9 +55,22 @@ public final class RouteSQLRewriteEngine {
}
private List<Object> getParameters(final ParameterBuilder parameterBuilder, final RouteResult routeResult, final RouteUnit routeUnit) {
if (parameterBuilder instanceof StandardParameterBuilder || routeResult.getOriginalDataNodes().isEmpty() || parameterBuilder.getParameters().isEmpty()) {
if (parameterBuilder instanceof StandardParameterBuilder) {
return parameterBuilder.getParameters();
}
if (routeResult.getOriginalDataNodes().isEmpty()) {
List<Object> onDuplicateKeyUpdateParameters = ((GroupedParameterBuilder) parameterBuilder).getOnDuplicateKeyUpdateParametersBuilder().getParameters();
if (onDuplicateKeyUpdateParameters.isEmpty()) {
return parameterBuilder.getParameters();
}
List<Object> result = new LinkedList<>();
result.addAll(parameterBuilder.getParameters());
result.addAll(onDuplicateKeyUpdateParameters);
return result;
}
List<Object> result = new LinkedList<>();
int count = 0;
for (Collection<DataNode> each : routeResult.getOriginalDataNodes()) {
......@@ -66,6 +79,7 @@ public final class RouteSQLRewriteEngine {
}
count++;
}
result.addAll(((GroupedParameterBuilder) parameterBuilder).getOnDuplicateKeyUpdateParametersBuilder().getParameters());
return result;
}
......
......@@ -35,16 +35,18 @@ public final class GroupedParameterBuilder implements ParameterBuilder {
private final List<StandardParameterBuilder> parameterBuilders;
@Getter
private final List<Object> onDuplicateKeyUpdateAddedParameters = new LinkedList<>();
private final StandardParameterBuilder onDuplicateKeyUpdateParametersBuilder;
@Setter
private String derivedColumnName;
public GroupedParameterBuilder(final List<List<Object>> groupedParameters) {
public GroupedParameterBuilder(final List<List<Object>> groupedParameters, final List<Object> onDuplicateKeyUpdateParameters) {
parameterBuilders = new ArrayList<>(groupedParameters.size());
for (List<Object> each : groupedParameters) {
parameterBuilders.add(new StandardParameterBuilder(each));
}
onDuplicateKeyUpdateParametersBuilder = new StandardParameterBuilder(onDuplicateKeyUpdateParameters);
}
@Override
......@@ -53,9 +55,6 @@ public final class GroupedParameterBuilder implements ParameterBuilder {
for (int i = 0; i < parameterBuilders.size(); i++) {
result.addAll(getParameters(i));
}
if (!onDuplicateKeyUpdateAddedParameters.isEmpty()) {
result.addAll(onDuplicateKeyUpdateAddedParameters);
}
return result;
}
......
......@@ -20,9 +20,11 @@ package org.apache.shardingsphere.underlying.rewrite.impl;
import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
......@@ -31,10 +33,40 @@ public final class GroupedParameterBuilderTest {
@Test
public void assertGetParameters() {
GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters());
GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters(), new ArrayList<>());
assertThat(actual.getParameters(), is(Arrays.<Object>asList(3, 4, 5, 6)));
}
@Test
public void assertGetParametersWithOnDuplicateKeyParameters() {
GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters(), createOnDuplicateKeyUpdateParameters());
assertThat(actual.getParameters(), is(Arrays.<Object>asList(3, 4, 5, 6)));
assertThat(actual.getOnDuplicateKeyUpdateParametersBuilder().getParameters(), is(Arrays.<Object>asList(7, 8)));
}
@Test
public void assertGetOnDuplicateKeyParametersWithModify() {
GroupedParameterBuilder actual = new GroupedParameterBuilder(new LinkedList<>(), createOnDuplicateKeyUpdateParameters());
actual.getOnDuplicateKeyUpdateParametersBuilder().addReplacedParameters(0, 77);
actual.getOnDuplicateKeyUpdateParametersBuilder().addReplacedParameters(1, 88);
actual.getOnDuplicateKeyUpdateParametersBuilder().addAddedParameters(0, Arrays.asList(66, -1));
actual.getOnDuplicateKeyUpdateParametersBuilder().addAddedParameters(2, Arrays.asList(99, 110));
actual.getOnDuplicateKeyUpdateParametersBuilder().addRemovedParameters(1);
assertThat(actual.getOnDuplicateKeyUpdateParametersBuilder().getParameters(), is(Arrays.<Object>asList(66, 77, 88, 99, 110)));
}
@Test
public void assertGetDerivedColumnName() {
GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters(), createOnDuplicateKeyUpdateParameters());
String derivedColumnName = "derivedColumnName";
actual.setDerivedColumnName(derivedColumnName);
assertThat(actual.getDerivedColumnName(), is(Optional.of(derivedColumnName)));
}
private List<Object> createOnDuplicateKeyUpdateParameters() {
return new LinkedList<>(Arrays.asList(7, 8));
}
private List<List<Object>> createGroupedParameters() {
List<List<Object>> result = new LinkedList<>();
result.add(Arrays.asList(3, 4));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册