From adf9c473f6c6a0b2541efb87594d44267ad5d119 Mon Sep 17 00:00:00 2001 From: Neil Dong Date: Fri, 15 May 2020 18:05:23 +0800 Subject: [PATCH] Issue5423 and issue5465 (#5590) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * BugFix:1.ON DUPLICATE KEY UPDATE Sub-Clause paramter dropped https://github.com/apache/shardingsphere/issues/5210 BugFix: 2.ON DUPLICATE KEY UPDATE Sub-Clause encrypt logic: missing assistedQueryColumn. * Add Apache License to org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext * 1.Add Tests for org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext 2.Improve Tests for GroupedParameterBuilder * 1.Improve Tests coverage of InsertStatementContextTest * 1.Remove EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter implement of QueryWithCipherColumnAware * 1.Improve Tests coverage of GroupedParameterBuilder ,OnDuplicateUpdateContext * 1.Improve Tests coverage of GroupedParameterBuilder * Add testcase , sharding test case will fail. * Add testcase , sharding test case will fail. * BugFix: 1.https://github.com/apache/shardingsphere/issues/5210 ON DUPLICATE KEY UPDATE Sub-Clause paramter dropped BugFix: 2.https://github.com/apache/shardingsphere/issues/5465 1)ON DUPLICATE KEY UPDATE Sub-Clause encrypt logic: missing assistedQueryColumn. 2)REMOVE wrong implements of QueryWithCipherColumnAware in org.apache.shardingsphere.encrypt.rewrite.parameter.impl.EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter * merge upstream * improve code style. * improve code style. * improve code style. * improve code style. * improve code style. * improve code style. * Remove redundant blank lines * Update OnDuplicateUpdateContext.java document format adjusted. * Update OnDuplicateUpdateContextTest.java ajust code style --- ...licateKeyUpdateValueParameterRewriter.java | 60 ++++---- .../EncryptInsertOnUpdateTokenGenerator.java | 15 ++ .../src/test/resources/encrypt/insert.xml | 9 +- .../src/test/resources/sharding/insert.xml | 6 + .../ShardingSpherePreparedStatementTest.java | 83 ++++++++++- .../values/OnDuplicateUpdateContext.java | 111 ++++++++++++++ .../statement/dml/InsertStatementContext.java | 45 ++++-- .../values/OnDuplicateUpdateContextTest.java | 136 ++++++++++++++++++ .../impl/InsertStatementContextTest.java | 36 ++++- .../mysql/visitor/impl/MySQLDMLVisitor.java | 1 + .../rewrite/context/SQLRewriteContext.java | 5 +- .../engine/GenericSQLRewriteEngine.java | 24 +++- .../rewrite/engine/RouteSQLRewriteEngine.java | 16 ++- .../builder/impl/GroupedParameterBuilder.java | 9 +- .../impl/GroupedParameterBuilderTest.java | 36 ++++- 15 files changed, 535 insertions(+), 57 deletions(-) create mode 100644 shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContext.java create mode 100644 shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContextTest.java diff --git a/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java b/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java index cce4eea020..540d8f06cd 100644 --- a/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java +++ b/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/impl/EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.java @@ -19,19 +19,17 @@ package org.apache.shardingsphere.encrypt.rewrite.parameter.impl; import com.google.common.base.Preconditions; import lombok.Setter; -import org.apache.shardingsphere.encrypt.rewrite.aware.QueryWithCipherColumnAware; import org.apache.shardingsphere.encrypt.rewrite.parameter.EncryptParameterRewriter; import org.apache.shardingsphere.encrypt.strategy.spi.Encryptor; +import org.apache.shardingsphere.encrypt.strategy.spi.QueryAssistedEncryptor; +import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext; import org.apache.shardingsphere.sql.parser.binder.statement.SQLStatementContext; import org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatementContext; -import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment; -import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.OnDuplicateKeyColumnsSegment; -import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment; -import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment; import org.apache.shardingsphere.underlying.rewrite.parameter.builder.ParameterBuilder; import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.GroupedParameterBuilder; import java.util.Collection; +import java.util.LinkedList; import java.util.List; import java.util.Optional; @@ -39,9 +37,7 @@ import java.util.Optional; * Insert on duplicate key update parameter rewriter for encrypt. */ @Setter -public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter extends EncryptParameterRewriter implements QueryWithCipherColumnAware { - - private boolean queryWithCipherColumn; +public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter extends EncryptParameterRewriter { @Override protected boolean isNeedRewriteForEncrypt(final SQLStatementContext sqlStatementContext) { @@ -51,30 +47,34 @@ public final class EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter exten @Override public void rewrite(final ParameterBuilder parameterBuilder, final InsertStatementContext insertStatementContext, final List parameters) { String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue(); - Preconditions.checkState(insertStatementContext.getSqlStatement().getOnDuplicateKeyColumns().isPresent()); - OnDuplicateKeyColumnsSegment onDuplicateKeyColumnsSegment = insertStatementContext.getSqlStatement().getOnDuplicateKeyColumns().get(); - Collection onDuplicateKeyColumnsSegments = onDuplicateKeyColumnsSegment.getColumns(); - if (onDuplicateKeyColumnsSegments.isEmpty()) { - return; - } GroupedParameterBuilder groupedParameterBuilder = (GroupedParameterBuilder) parameterBuilder; - for (AssignmentSegment each : onDuplicateKeyColumnsSegments) { - ExpressionSegment expressionSegment = each.getValue(); - Object cipherColumnValue; - Object plainColumnValue = null; - if (expressionSegment instanceof ParameterMarkerExpressionSegment) { - plainColumnValue = parameters.get(((ParameterMarkerExpressionSegment) expressionSegment).getParameterMarkerIndex()); - } - if (queryWithCipherColumn) { - Optional encryptor = getEncryptRule().findEncryptor(tableName, each.getColumn().getIdentifier().getValue()); - if (encryptor.isPresent()) { - cipherColumnValue = encryptor.get().encrypt(plainColumnValue); - groupedParameterBuilder.getOnDuplicateKeyUpdateAddedParameters().add(cipherColumnValue); + OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext = insertStatementContext.getOnDuplicateKeyUpdateValueContext(); + for (int index = 0; index < onDuplicateKeyUpdateValueContext.getValueExpressions().size(); index++) { + final int columnIndex = index; + String encryptLogicColumnName = onDuplicateKeyUpdateValueContext.getColumn(columnIndex).getIdentifier().getValue(); + Optional encryptorOptional = getEncryptRule().findEncryptor(tableName, encryptLogicColumnName); + encryptorOptional.ifPresent(encryptor -> { + Object plainColumnValue = onDuplicateKeyUpdateValueContext.getValue(columnIndex); + Object cipherColumnValue = encryptorOptional.get().encrypt(plainColumnValue); + groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().addReplacedParameters(columnIndex, cipherColumnValue); + Collection addedParameters = new LinkedList<>(); + if (encryptor instanceof QueryAssistedEncryptor) { + Optional assistedColumnName = getEncryptRule().findAssistedQueryColumn(tableName, encryptLogicColumnName); + Preconditions.checkArgument(assistedColumnName.isPresent(), "Can not find assisted query Column Name"); + addedParameters.add(((QueryAssistedEncryptor) encryptor).queryAssistedEncrypt(plainColumnValue.toString())); + } + + if (getEncryptRule().findPlainColumn(tableName, encryptLogicColumnName).isPresent()) { + addedParameters.add(plainColumnValue); + } + + if (!addedParameters.isEmpty()) { + if (!groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().getAddedIndexAndParameters().containsKey(columnIndex + 1)) { + groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().getAddedIndexAndParameters().put(columnIndex + 1, new LinkedList<>()); + } + groupedParameterBuilder.getOnDuplicateKeyUpdateParametersBuilder().getAddedIndexAndParameters().get(columnIndex + 1).addAll(addedParameters); } - } - if (null != plainColumnValue) { - groupedParameterBuilder.getOnDuplicateKeyUpdateAddedParameters().add(plainColumnValue); - } + }); } } } diff --git a/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java b/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java index 2870ea8878..0dba5e82d5 100644 --- a/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java +++ b/encrypt-core/encrypt-core-rewrite/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/impl/EncryptInsertOnUpdateTokenGenerator.java @@ -77,6 +77,7 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex()); String columnName = assignmentSegment.getColumn().getIdentifier().getValue(); addCipherColumn(tableName, columnName, result); + addAssistedQueryColumn(tableName, columnName, result); addPlainColumn(tableName, columnName, result); return result; } @@ -84,6 +85,7 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok private EncryptAssignmentToken generateLiteralSQLToken(final String tableName, final AssignmentSegment assignmentSegment) { EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumn().getStartIndex(), assignmentSegment.getStopIndex()); addCipherAssignment(tableName, assignmentSegment, result); + addAssistedQueryAssignment(tableName, assignmentSegment, result); addPlainAssignment(tableName, assignmentSegment, result); return result; } @@ -91,6 +93,10 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok private void addCipherColumn(final String tableName, final String columnName, final EncryptParameterAssignmentToken token) { token.addColumnName(getEncryptRule().getCipherColumn(tableName, columnName)); } + + private void addAssistedQueryColumn(final String tableName, final String columnName, final EncryptParameterAssignmentToken token) { + getEncryptRule().findAssistedQueryColumn(tableName, columnName).ifPresent(token::addColumnName); + } private void addPlainColumn(final String tableName, final String columnName, final EncryptParameterAssignmentToken token) { getEncryptRule().findPlainColumn(tableName, columnName).ifPresent(token::addColumnName); @@ -102,6 +108,15 @@ public final class EncryptInsertOnUpdateTokenGenerator extends BaseEncryptSQLTok token.addAssignment(getEncryptRule().getCipherColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()), cipherValue); } + private void addAssistedQueryAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) { + getEncryptRule().findAssistedQueryColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()).ifPresent(assistedQueryColumn -> { + Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals(); + Object assistedQueryValue = getEncryptRule().getEncryptAssistedQueryValues(tableName, assignmentSegment.getColumn().getIdentifier().getValue(), Collections.singletonList(originalValue)) + .iterator().next(); + token.addAssignment(assistedQueryColumn, assistedQueryValue); + }); + } + private void addPlainAssignment(final String tableName, final AssignmentSegment assignmentSegment, final EncryptLiteralAssignmentToken token) { Object originalValue = ((LiteralExpressionSegment) assignmentSegment.getValue()).getLiterals(); getEncryptRule().findPlainColumn(tableName, assignmentSegment.getColumn().getIdentifier().getValue()).ifPresent(plainColumn -> token.addAssignment(plainColumn, originalValue)); diff --git a/encrypt-core/encrypt-core-rewrite/src/test/resources/encrypt/insert.xml b/encrypt-core/encrypt-core-rewrite/src/test/resources/encrypt/insert.xml index a8aa1d5add..bc41a7963f 100644 --- a/encrypt-core/encrypt-core-rewrite/src/test/resources/encrypt/insert.xml +++ b/encrypt-core/encrypt-core-rewrite/src/test/resources/encrypt/insert.xml @@ -44,12 +44,17 @@ - + - + + + + + + diff --git a/sharding-core/sharding-core-rewrite/src/test/resources/sharding/insert.xml b/sharding-core/sharding-core-rewrite/src/test/resources/sharding/insert.xml index ff5c1e5779..a2ed3ca4a6 100644 --- a/sharding-core/sharding-core-rewrite/src/test/resources/sharding/insert.xml +++ b/sharding-core/sharding-core-rewrite/src/test/resources/sharding/insert.xml @@ -125,4 +125,10 @@ + + + + + + diff --git a/sharding-jdbc/sharding-jdbc-core/src/test/java/org/apache/shardingsphere/shardingjdbc/jdbc/core/statement/ShardingSpherePreparedStatementTest.java b/sharding-jdbc/sharding-jdbc-core/src/test/java/org/apache/shardingsphere/shardingjdbc/jdbc/core/statement/ShardingSpherePreparedStatementTest.java index e41a4c3eaa..70d2bbe0e1 100644 --- a/sharding-jdbc/sharding-jdbc-core/src/test/java/org/apache/shardingsphere/shardingjdbc/jdbc/core/statement/ShardingSpherePreparedStatementTest.java +++ b/sharding-jdbc/sharding-jdbc-core/src/test/java/org/apache/shardingsphere/shardingjdbc/jdbc/core/statement/ShardingSpherePreparedStatementTest.java @@ -44,10 +44,14 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ private static final String INSERT_WITHOUT_GENERATE_KEY_SQL = "INSERT INTO t_order_item (order_id, user_id, status) VALUES (?, ?, ?)"; + private static final String INSERT_ON_DUPLICATE_KEY_SQL = "INSERT INTO t_order_item (item_id, order_id, user_id, status) VALUES (?, ?, ?, ?), (?, ?, ?, ?) ON DUPLICATE KEY UPDATE status = ?"; + private static final String SELECT_SQL_WITHOUT_PARAMETER_MARKER = "SELECT item_id FROM t_order_item WHERE user_id = %d AND order_id= %s AND status = 'BATCH'"; private static final String SELECT_SQL_WITH_PARAMETER_MARKER = "SELECT item_id FROM t_order_item WHERE user_id = ? AND order_id= ? AND status = 'BATCH'"; + private static final String SELECT_SQL_WITH_PARAMETER_MARKER_RETURN_STATUS = "SELECT item_id, user_id, status FROM t_order_item WHERE order_id= ? AND user_id = ?"; + private static final String UPDATE_SQL = "UPDATE t_order SET status = ? WHERE user_id = ? AND order_id = ?"; private static final String UPDATE_BATCH_SQL = "UPDATE t_order SET status=? WHERE status=?"; @@ -83,7 +87,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ } } } - + @Ignore @Test public void assertMultiValuesWithGenerateShardingKeyColumn() throws SQLException { @@ -126,7 +130,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ } } } - + @Ignore @Test public void assertAddBatchMultiValuesWithGenerateShardingKeyColumn() throws SQLException { @@ -316,6 +320,75 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ } } + @Test + public void assertAddOnDuplicateKey() throws SQLException { + int itemId = 1; + int userId1 = 101; + int userId2 = 102; + int orderId = 200; + String status = "init"; + String updatedStatus = "updated on duplicate key"; + try (Connection connection = getShardingSphereDataSource().getConnection(); + PreparedStatement preparedStatement = connection.prepareStatement(INSERT_ON_DUPLICATE_KEY_SQL); + PreparedStatement queryStatement = connection.prepareStatement(SELECT_SQL_WITH_PARAMETER_MARKER_RETURN_STATUS)) { + preparedStatement.setInt(1, itemId); + preparedStatement.setInt(2, orderId); + preparedStatement.setInt(3, userId1); + preparedStatement.setString(4, status); + preparedStatement.setInt(5, itemId); + preparedStatement.setInt(6, orderId); + preparedStatement.setInt(7, userId2); + preparedStatement.setString(8, status); + preparedStatement.setString(9, updatedStatus); + int result = preparedStatement.executeUpdate(); + assertThat(result, is(2)); + queryStatement.setInt(1, orderId); + queryStatement.setInt(2, userId1); + try (ResultSet resultSet = queryStatement.executeQuery()) { + assertTrue(resultSet.next()); + assertThat(resultSet.getInt(2), is(userId1)); + assertThat(resultSet.getString(3), is(status)); + } + queryStatement.setInt(1, orderId); + queryStatement.setInt(2, userId2); + try (ResultSet resultSet = queryStatement.executeQuery()) { + assertTrue(resultSet.next()); + assertThat(resultSet.getInt(2), is(userId2)); + assertThat(resultSet.getString(3), is(status)); + } + } + + try (Connection connection = getShardingSphereDataSource().getConnection(); + PreparedStatement preparedStatement = connection.prepareStatement(INSERT_ON_DUPLICATE_KEY_SQL); + PreparedStatement queryStatement = connection.prepareStatement(SELECT_SQL_WITH_PARAMETER_MARKER_RETURN_STATUS)) { + preparedStatement.setInt(1, itemId); + preparedStatement.setInt(2, orderId); + preparedStatement.setInt(3, userId1); + preparedStatement.setString(4, status); + preparedStatement.setInt(5, itemId); + preparedStatement.setInt(6, orderId); + preparedStatement.setInt(7, userId2); + preparedStatement.setString(8, status); + preparedStatement.setString(9, updatedStatus); + int result = preparedStatement.executeUpdate(); + assertThat(result, is(2)); + queryStatement.setInt(1, orderId); + queryStatement.setInt(2, userId1); + try (ResultSet resultSet = queryStatement.executeQuery()) { + assertTrue(resultSet.next()); + assertThat(resultSet.getInt(2), is(userId1)); + assertThat(resultSet.getString(3), is(updatedStatus)); + } + queryStatement.setInt(1, orderId); + queryStatement.setInt(2, userId2); + try (ResultSet resultSet = queryStatement.executeQuery()) { + assertTrue(resultSet.next()); + assertThat(resultSet.getInt(2), is(userId2)); + assertThat(resultSet.getString(3), is(updatedStatus)); + } + } + } + @Test public void assertUpdateBatch() throws SQLException { try ( @@ -337,7 +410,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ assertThat(result[2], is(4)); } } - + @Test public void assertExecuteGetResultSet() throws SQLException { try (PreparedStatement preparedStatement = getShardingSphereDataSource().getConnection().prepareStatement(UPDATE_SQL)) { @@ -348,7 +421,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ assertNull(preparedStatement.getResultSet()); } } - + @Test public void assertExecuteUpdateGetResultSet() throws SQLException { try (PreparedStatement preparedStatement = getShardingSphereDataSource().getConnection().prepareStatement(UPDATE_SQL)) { @@ -359,7 +432,7 @@ public final class ShardingSpherePreparedStatementTest extends AbstractShardingJ assertNull(preparedStatement.getResultSet()); } } - + @Test public void assertClearBatch() throws SQLException { try ( diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContext.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContext.java new file mode 100644 index 0000000000..5da791b4fe --- /dev/null +++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContext.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.sql.parser.binder.segment.insert.values; + +import lombok.Getter; +import lombok.ToString; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +@Getter +@ToString +public class OnDuplicateUpdateContext { + private final int parametersCount; + + private final List valueExpressions; + + private final List parameters; + + private final List columns; + + public OnDuplicateUpdateContext(final Collection assignments, final List parameters, final int parametersOffset) { + List expressionSegments = assignments.stream().map(AssignmentSegment::getValue).collect(Collectors.toList()); + parametersCount = calculateParametersCount(expressionSegments); + valueExpressions = getValueExpressions(expressionSegments); + this.parameters = getParameters(parameters, parametersOffset); + columns = assignments.stream().map(AssignmentSegment::getColumn).collect(Collectors.toList()); + } + + private int calculateParametersCount(final Collection assignments) { + int result = 0; + for (ExpressionSegment each : assignments) { + if (each instanceof ParameterMarkerExpressionSegment) { + result++; + } + } + return result; + } + + private List getValueExpressions(final Collection assignments) { + List result = new ArrayList<>(assignments.size()); + result.addAll(assignments); + return result; + } + + private List getParameters(final List parameters, final int parametersOffset) { + if (0 == parametersCount) { + return Collections.emptyList(); + } + List result = new ArrayList<>(parametersCount); + result.addAll(parameters.subList(parametersOffset, parametersOffset + parametersCount)); + return result; + } + + /** + * Get value. + * + * @param index index + * @return value + */ + public Object getValue(final int index) { + ExpressionSegment valueExpression = valueExpressions.get(index); + return valueExpression instanceof ParameterMarkerExpressionSegment ? parameters.get(getParameterIndex(valueExpression)) : ((LiteralExpressionSegment) valueExpression).getLiterals(); + } + + private int getParameterIndex(final ExpressionSegment valueExpression) { + int result = 0; + for (ExpressionSegment each : valueExpressions) { + if (valueExpression == each) { + return result; + } + if (each instanceof ParameterMarkerExpressionSegment) { + result++; + } + } + throw new IllegalArgumentException("Can not get parameter index."); + } + + /** + * Get on duplicate key update column by index of this clause. + * + * @param index index + * @return columnSegment + */ + public ColumnSegment getColumn(final int index) { + return columns.get(index); + } +} diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java index b47e8dc5e6..b10bae5e6e 100644 --- a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java +++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/InsertStatementContext.java @@ -23,19 +23,23 @@ import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaDat import org.apache.shardingsphere.sql.parser.binder.segment.insert.keygen.GeneratedKeyContext; import org.apache.shardingsphere.sql.parser.binder.segment.insert.keygen.engine.GeneratedKeyContextEngine; import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.InsertValueContext; +import org.apache.shardingsphere.sql.parser.binder.segment.insert.values.OnDuplicateUpdateContext; import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext; import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementContext; import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment; import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; /** * Insert SQL statement context. @@ -50,30 +54,43 @@ public final class InsertStatementContext extends CommonSQLStatementContext insertValueContexts; + private final OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext; + private final GeneratedKeyContext generatedKeyContext; public InsertStatementContext(final SchemaMetaData schemaMetaData, final List parameters, final InsertStatement sqlStatement) { super(sqlStatement); tablesContext = new TablesContext(sqlStatement.getTable()); columnNames = sqlStatement.useDefaultColumns() ? schemaMetaData.getAllColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue()) : sqlStatement.getColumnNames(); - insertValueContexts = getInsertValueContexts(parameters); + AtomicInteger parametersOffset = new AtomicInteger(0); + insertValueContexts = getInsertValueContexts(parameters, parametersOffset); + onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(parameters, parametersOffset).orElse(null); generatedKeyContext = new GeneratedKeyContextEngine(schemaMetaData).createGenerateKeyContext(parameters, sqlStatement).orElse(null); } - private List getInsertValueContexts(final List parameters) { + private List getInsertValueContexts(final List parameters, final AtomicInteger parametersOffset) { List result = new LinkedList<>(); - int parametersOffset = 0; for (Collection each : getSqlStatement().getAllValueExpressions()) { - InsertValueContext insertValueContext = new InsertValueContext(each, parameters, parametersOffset); + InsertValueContext insertValueContext = new InsertValueContext(each, parameters, parametersOffset.get()); result.add(insertValueContext); - parametersOffset += insertValueContext.getParametersCount(); + parametersOffset.addAndGet(insertValueContext.getParametersCount()); } return result; } + private Optional getOnDuplicateKeyUpdateValueContext(final List parameters, final AtomicInteger parametersOffset) { + if (!getSqlStatement().getOnDuplicateKeyColumns().isPresent()) { + return Optional.empty(); + } + Collection onDuplicateKeyColumns = getSqlStatement().getOnDuplicateKeyColumns().get().getColumns(); + OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(onDuplicateKeyColumns, parameters, parametersOffset.get()); + parametersOffset.addAndGet(onDuplicateUpdateContext.getParametersCount()); + return Optional.of(onDuplicateUpdateContext); + } + /** * Get column names for descending order. - * + * * @return column names for descending order */ public Iterator getDescendingColumnNames() { @@ -82,7 +99,7 @@ public final class InsertStatementContext extends CommonSQLStatementContext> getGroupedParameters() { @@ -93,9 +110,21 @@ public final class InsertStatementContext extends CommonSQLStatementContext getOnDuplicateKeyUpdateParameters() { + if (null == onDuplicateKeyUpdateValueContext) { + return new ArrayList<>(0); + } + return onDuplicateKeyUpdateValueContext.getParameters(); + } + /** * Get generated key context. - * + * * @return generated key context */ public Optional getGeneratedKeyContext() { diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContextTest.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContextTest.java new file mode 100644 index 0000000000..501cef9a7b --- /dev/null +++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/segment/insert/values/OnDuplicateUpdateContextTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.sql.parser.binder.segment.insert.values; + +import com.google.common.collect.Lists; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.SimpleExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.value.identifier.IdentifierValue; +import org.junit.Assert; +import org.junit.Test; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +public final class OnDuplicateUpdateContextTest { + + @SuppressWarnings("unchecked") + @Test + public void assertInstanceConstructedOk() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Collection assignments = Lists.newArrayList(); + List parameters = Collections.emptyList(); + int parametersOffset = 0; + OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, parametersOffset); + Method calculateParametersCountMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("calculateParametersCount", Collection.class); + calculateParametersCountMethod.setAccessible(true); + int calculateParametersCountResult = (int) calculateParametersCountMethod.invoke(onDuplicateUpdateContext, new Object[]{assignments}); + assertThat(onDuplicateUpdateContext.getParametersCount(), is(calculateParametersCountResult)); + Method getValueExpressionsMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("getValueExpressions", Collection.class); + getValueExpressionsMethod.setAccessible(true); + List getValueExpressionsResult = (List) getValueExpressionsMethod.invoke(onDuplicateUpdateContext, new Object[]{assignments}); + assertThat(onDuplicateUpdateContext.getValueExpressions(), is(getValueExpressionsResult)); + Method getParametersMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("getParameters", List.class, int.class); + getParametersMethod.setAccessible(true); + List getParametersResult = (List) getParametersMethod.invoke(onDuplicateUpdateContext, new Object[]{parameters, parametersOffset}); + assertThat(onDuplicateUpdateContext.getParameters(), is(getParametersResult)); + } + + @Test + public void assertGetValueWhenParameterMarker() { + Collection assignments = makeParameterMarkerExpressionAssignmentSegment(); + String parameterValue1 = "test1"; + String parameterValue2 = "test2"; + List parameters = Lists.newArrayList(parameterValue1, parameterValue2); + int parametersOffset = 0; + OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, parametersOffset); + Object valueFromInsertValueContext1 = onDuplicateUpdateContext.getValue(0); + assertThat(valueFromInsertValueContext1, is(parameterValue1)); + Object valueFromInsertValueContext2 = onDuplicateUpdateContext.getValue(1); + assertThat(valueFromInsertValueContext2, is(parameterValue2)); + } + + private Collection makeParameterMarkerExpressionAssignmentSegment() { + ParameterMarkerExpressionSegment parameterMarkerExpressionSegment = new ParameterMarkerExpressionSegment(0, 10, 5); + AssignmentSegment assignmentSegment1 = makeAssignmentSegment(parameterMarkerExpressionSegment); + ParameterMarkerExpressionSegment parameterMarkerExpressionSegment2 = new ParameterMarkerExpressionSegment(0, 10, 6); + AssignmentSegment assignmentSegment2 = makeAssignmentSegment(parameterMarkerExpressionSegment2); + return Lists.newArrayList(assignmentSegment1, assignmentSegment2); + } + + @Test + public void assertGetValueWhenLiteralExpressionSegment() { + Object literalObject = new Object(); + Collection assignments = makeLiteralExpressionSegment(literalObject); + List parameters = Collections.emptyList(); + OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, 0); + Object valueFromInsertValueContext = onDuplicateUpdateContext.getValue(0); + assertThat(valueFromInsertValueContext, is(literalObject)); + } + + private Collection makeLiteralExpressionSegment(final Object literalObject) { + LiteralExpressionSegment parameterLiteralExpression = new LiteralExpressionSegment(0, 10, literalObject); + AssignmentSegment assignmentSegment = makeAssignmentSegment(parameterLiteralExpression); + return Collections.singleton(assignmentSegment); + } + + private AssignmentSegment makeAssignmentSegment(final SimpleExpressionSegment expressionSegment) { + int doesNotMatterLexicalIndex = 0; + String doesNotMatterColumnName = "columnNameStr"; + ColumnSegment column = new ColumnSegment(doesNotMatterLexicalIndex, doesNotMatterLexicalIndex, new IdentifierValue(doesNotMatterColumnName)); + return new AssignmentSegment(doesNotMatterLexicalIndex, doesNotMatterLexicalIndex, column, expressionSegment); + } + + @Test + public void assertGetParameterIndex() throws NoSuchMethodException, IllegalAccessException { + Collection assignments = Lists.newArrayList(); + List parameters = Collections.emptyList(); + int parametersOffset = 0; + OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, parametersOffset); + Method getParameterIndexMethod = OnDuplicateUpdateContext.class.getDeclaredMethod("getParameterIndex", ExpressionSegment.class); + getParameterIndexMethod.setAccessible(true); + ParameterMarkerExpressionSegment notExistsExpressionSegment = new ParameterMarkerExpressionSegment(0, 0, 0); + Throwable targetException = null; + try { + getParameterIndexMethod.invoke(onDuplicateUpdateContext, notExistsExpressionSegment); + } catch (InvocationTargetException e) { + targetException = e.getTargetException(); + } + assertTrue("expected throw IllegalArgumentException", targetException instanceof IllegalArgumentException); + } + + @Test + public void assertGetColumn() { + Object literalObject = new Object(); + Collection assignments = makeLiteralExpressionSegment(literalObject); + List parameters = Collections.emptyList(); + OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(assignments, parameters, 0); + ColumnSegment column = onDuplicateUpdateContext.getColumn(0); + Assert.assertThat(column, is(assignments.iterator().next().getColumn())); + } +} diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/statement/impl/InsertStatementContextTest.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/statement/impl/InsertStatementContextTest.java index f27f3ddfdd..84b4118774 100644 --- a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/statement/impl/InsertStatementContextTest.java +++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/test/java/org/apache/shardingsphere/sql/parser/binder/statement/impl/InsertStatementContextTest.java @@ -19,9 +19,11 @@ package org.apache.shardingsphere.sql.parser.binder.statement.impl; import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaData; import org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatementContext; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment; import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment; import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment; +import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.OnDuplicateKeyColumnsSegment; import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.segment.generic.table.SimpleTableSegment; @@ -32,6 +34,7 @@ import org.junit.Test; import java.util.Arrays; import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -62,7 +65,7 @@ public final class InsertStatementContextTest { } @Test - public void assertGetGroupedParameters() { + public void assertGetGroupedParametersWithoutOnDuplicateParameter() { SchemaMetaData schemaMetaData = mock(SchemaMetaData.class); when(schemaMetaData.getAllColumnNames("tbl")).thenReturn(Arrays.asList("id", "name", "status")); InsertStatement insertStatement = new InsertStatement(); @@ -70,6 +73,22 @@ public final class InsertStatementContextTest { setUpInsertValues(insertStatement); InsertStatementContext actual = new InsertStatementContext(schemaMetaData, Arrays.asList(1, "Tom", 2, "Jerry"), insertStatement); assertThat(actual.getGroupedParameters().size(), is(2)); + assertNull(actual.getOnDuplicateKeyUpdateValueContext()); + assertThat(actual.getOnDuplicateKeyUpdateParameters().size(), is(0)); + } + + @Test + public void assertGetGroupedParametersWithOnDuplicateParameters() { + SchemaMetaData schemaMetaData = mock(SchemaMetaData.class); + when(schemaMetaData.getAllColumnNames("tbl")).thenReturn(Arrays.asList("id", "name", "status")); + InsertStatement insertStatement = new InsertStatement(); + insertStatement.setTable(new SimpleTableSegment(0, 0, new IdentifierValue("tbl"))); + setUpInsertValues(insertStatement); + setUpOnDuplicateValues(insertStatement); + InsertStatementContext actual = new InsertStatementContext(schemaMetaData, Arrays.asList(1, "Tom", 2, "Jerry", "onDuplicateKeyUpdateColumnValue"), insertStatement); + assertThat(actual.getGroupedParameters().size(), is(2)); + assertThat(actual.getOnDuplicateKeyUpdateValueContext().getColumns().size(), is(2)); + assertThat(actual.getOnDuplicateKeyUpdateParameters().size(), is(1)); } private void setUpInsertValues(final InsertStatement insertStatement) { @@ -79,6 +98,21 @@ public final class InsertStatementContextTest { new ParameterMarkerExpressionSegment(0, 0, 3), new ParameterMarkerExpressionSegment(0, 0, 4), new LiteralExpressionSegment(0, 0, "init")))); } + private void setUpOnDuplicateValues(final InsertStatement insertStatement) { + AssignmentSegment parameterMarkerExpressionAssignment = new AssignmentSegment(0, 0, + new ColumnSegment(0, 0, new IdentifierValue("on_duplicate_key_update_column_1")), + new ParameterMarkerExpressionSegment(0, 0, 4) + ); + AssignmentSegment literalExpressionAssignment = new AssignmentSegment(0, 0, + new ColumnSegment(0, 0, new IdentifierValue("on_duplicate_key_update_column_2")), + new LiteralExpressionSegment(0, 0, 5) + ); + OnDuplicateKeyColumnsSegment onDuplicateKeyColumnsSegment = new OnDuplicateKeyColumnsSegment(0, 0, Arrays.asList( + parameterMarkerExpressionAssignment, literalExpressionAssignment + )); + insertStatement.setOnDuplicateKeyColumns(onDuplicateKeyColumnsSegment); + } + private void assertInsertStatementContext(final InsertStatementContext actual) { assertThat(actual.getColumnNames(), is(Arrays.asList("id", "name", "status"))); assertThat(actual.getInsertValueContexts().size(), is(2)); diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/impl/MySQLDMLVisitor.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/impl/MySQLDMLVisitor.java index b3e99492f0..7bca29671e 100644 --- a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/impl/MySQLDMLVisitor.java +++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/impl/MySQLDMLVisitor.java @@ -198,6 +198,7 @@ public final class MySQLDMLVisitor extends MySQLVisitor implements DMLVisitor { Collection columns = new LinkedList<>(); for (AssignmentContext each : ctx.assignment()) { columns.add((AssignmentSegment) visit(each)); + visit(each.assignmentValue()); } return new OnDuplicateKeyColumnsSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), columns); } diff --git a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/context/SQLRewriteContext.java b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/context/SQLRewriteContext.java index 6b358a6383..815336c254 100644 --- a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/context/SQLRewriteContext.java +++ b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/context/SQLRewriteContext.java @@ -62,12 +62,13 @@ public final class SQLRewriteContext { this.parameters = parameters; addSQLTokenGenerators(new DefaultTokenGeneratorBuilder().getSQLTokenGenerators()); parameterBuilder = sqlStatementContext instanceof InsertStatementContext - ? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters()) : new StandardParameterBuilder(parameters); + ? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters()) + : new StandardParameterBuilder(parameters); } /** * Add SQL token generators. - * + * * @param sqlTokenGenerators SQL token generators */ public void addSQLTokenGenerators(final Collection sqlTokenGenerators) { diff --git a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/GenericSQLRewriteEngine.java b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/GenericSQLRewriteEngine.java index dcd9b85f6c..25666073d3 100644 --- a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/GenericSQLRewriteEngine.java +++ b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/GenericSQLRewriteEngine.java @@ -20,8 +20,14 @@ package org.apache.shardingsphere.underlying.rewrite.engine; import org.apache.shardingsphere.underlying.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.underlying.rewrite.engine.result.GenericSQLRewriteResult; import org.apache.shardingsphere.underlying.rewrite.engine.result.SQLRewriteUnit; +import org.apache.shardingsphere.underlying.rewrite.parameter.builder.ParameterBuilder; +import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.GroupedParameterBuilder; +import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.StandardParameterBuilder; import org.apache.shardingsphere.underlying.rewrite.sql.impl.DefaultSQLBuilder; +import java.util.LinkedList; +import java.util.List; + /** * Generic SQL rewrite engine. */ @@ -34,6 +40,22 @@ public final class GenericSQLRewriteEngine { * @return SQL rewrite result */ public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) { - return new GenericSQLRewriteResult(new SQLRewriteUnit(new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getParameterBuilder().getParameters())); + return new GenericSQLRewriteResult(new SQLRewriteUnit(new DefaultSQLBuilder(sqlRewriteContext).toSQL(), getParameters(sqlRewriteContext.getParameterBuilder()))); + } + + private List getParameters(final ParameterBuilder parameterBuilder) { + if (parameterBuilder instanceof StandardParameterBuilder) { + return parameterBuilder.getParameters(); + } + + List onDuplicateKeyUpdateParameters = ((GroupedParameterBuilder) parameterBuilder).getOnDuplicateKeyUpdateParametersBuilder().getParameters(); + if (onDuplicateKeyUpdateParameters.isEmpty()) { + return parameterBuilder.getParameters(); + } + + List result = new LinkedList<>(); + result.addAll(parameterBuilder.getParameters()); + result.addAll(onDuplicateKeyUpdateParameters); + return result; } } diff --git a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/RouteSQLRewriteEngine.java b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/RouteSQLRewriteEngine.java index 0649916a08..871bd753cb 100644 --- a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/RouteSQLRewriteEngine.java +++ b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/engine/RouteSQLRewriteEngine.java @@ -55,9 +55,22 @@ public final class RouteSQLRewriteEngine { } private List getParameters(final ParameterBuilder parameterBuilder, final RouteResult routeResult, final RouteUnit routeUnit) { - if (parameterBuilder instanceof StandardParameterBuilder || routeResult.getOriginalDataNodes().isEmpty() || parameterBuilder.getParameters().isEmpty()) { + if (parameterBuilder instanceof StandardParameterBuilder) { return parameterBuilder.getParameters(); } + + if (routeResult.getOriginalDataNodes().isEmpty()) { + List onDuplicateKeyUpdateParameters = ((GroupedParameterBuilder) parameterBuilder).getOnDuplicateKeyUpdateParametersBuilder().getParameters(); + if (onDuplicateKeyUpdateParameters.isEmpty()) { + return parameterBuilder.getParameters(); + } + + List result = new LinkedList<>(); + result.addAll(parameterBuilder.getParameters()); + result.addAll(onDuplicateKeyUpdateParameters); + return result; + } + List result = new LinkedList<>(); int count = 0; for (Collection each : routeResult.getOriginalDataNodes()) { @@ -66,6 +79,7 @@ public final class RouteSQLRewriteEngine { } count++; } + result.addAll(((GroupedParameterBuilder) parameterBuilder).getOnDuplicateKeyUpdateParametersBuilder().getParameters()); return result; } diff --git a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/parameter/builder/impl/GroupedParameterBuilder.java b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/parameter/builder/impl/GroupedParameterBuilder.java index 711f5f0b7c..c285f0a992 100644 --- a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/parameter/builder/impl/GroupedParameterBuilder.java +++ b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/main/java/org/apache/shardingsphere/underlying/rewrite/parameter/builder/impl/GroupedParameterBuilder.java @@ -35,16 +35,18 @@ public final class GroupedParameterBuilder implements ParameterBuilder { private final List parameterBuilders; @Getter - private final List onDuplicateKeyUpdateAddedParameters = new LinkedList<>(); + private final StandardParameterBuilder onDuplicateKeyUpdateParametersBuilder; @Setter private String derivedColumnName; - public GroupedParameterBuilder(final List> groupedParameters) { + public GroupedParameterBuilder(final List> groupedParameters, final List onDuplicateKeyUpdateParameters) { parameterBuilders = new ArrayList<>(groupedParameters.size()); for (List each : groupedParameters) { parameterBuilders.add(new StandardParameterBuilder(each)); } + + onDuplicateKeyUpdateParametersBuilder = new StandardParameterBuilder(onDuplicateKeyUpdateParameters); } @Override @@ -53,9 +55,6 @@ public final class GroupedParameterBuilder implements ParameterBuilder { for (int i = 0; i < parameterBuilders.size(); i++) { result.addAll(getParameters(i)); } - if (!onDuplicateKeyUpdateAddedParameters.isEmpty()) { - result.addAll(onDuplicateKeyUpdateAddedParameters); - } return result; } diff --git a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/test/java/org/apache/shardingsphere/underlying/rewrite/impl/GroupedParameterBuilderTest.java b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/test/java/org/apache/shardingsphere/underlying/rewrite/impl/GroupedParameterBuilderTest.java index 0c83b4d772..96f80f8c97 100644 --- a/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/test/java/org/apache/shardingsphere/underlying/rewrite/impl/GroupedParameterBuilderTest.java +++ b/shardingsphere-underlying/shardingsphere-rewrite/shardingsphere-rewrite-engine/src/test/java/org/apache/shardingsphere/underlying/rewrite/impl/GroupedParameterBuilderTest.java @@ -20,9 +20,11 @@ package org.apache.shardingsphere.underlying.rewrite.impl; import org.apache.shardingsphere.underlying.rewrite.parameter.builder.impl.GroupedParameterBuilder; import org.junit.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedList; import java.util.List; +import java.util.Optional; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; @@ -31,10 +33,40 @@ public final class GroupedParameterBuilderTest { @Test public void assertGetParameters() { - GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters()); + GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters(), new ArrayList<>()); assertThat(actual.getParameters(), is(Arrays.asList(3, 4, 5, 6))); } - + + @Test + public void assertGetParametersWithOnDuplicateKeyParameters() { + GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters(), createOnDuplicateKeyUpdateParameters()); + assertThat(actual.getParameters(), is(Arrays.asList(3, 4, 5, 6))); + assertThat(actual.getOnDuplicateKeyUpdateParametersBuilder().getParameters(), is(Arrays.asList(7, 8))); + } + + @Test + public void assertGetOnDuplicateKeyParametersWithModify() { + GroupedParameterBuilder actual = new GroupedParameterBuilder(new LinkedList<>(), createOnDuplicateKeyUpdateParameters()); + actual.getOnDuplicateKeyUpdateParametersBuilder().addReplacedParameters(0, 77); + actual.getOnDuplicateKeyUpdateParametersBuilder().addReplacedParameters(1, 88); + actual.getOnDuplicateKeyUpdateParametersBuilder().addAddedParameters(0, Arrays.asList(66, -1)); + actual.getOnDuplicateKeyUpdateParametersBuilder().addAddedParameters(2, Arrays.asList(99, 110)); + actual.getOnDuplicateKeyUpdateParametersBuilder().addRemovedParameters(1); + assertThat(actual.getOnDuplicateKeyUpdateParametersBuilder().getParameters(), is(Arrays.asList(66, 77, 88, 99, 110))); + } + + @Test + public void assertGetDerivedColumnName() { + GroupedParameterBuilder actual = new GroupedParameterBuilder(createGroupedParameters(), createOnDuplicateKeyUpdateParameters()); + String derivedColumnName = "derivedColumnName"; + actual.setDerivedColumnName(derivedColumnName); + assertThat(actual.getDerivedColumnName(), is(Optional.of(derivedColumnName))); + } + + private List createOnDuplicateKeyUpdateParameters() { + return new LinkedList<>(Arrays.asList(7, 8)); + } + private List> createGroupedParameters() { List> result = new LinkedList<>(); result.add(Arrays.asList(3, 4)); -- GitLab