Unverified commit 7a2b03c4, authored by avalon5666, committed by GitHub

Use insert on duplicate key update in mysql insert (#8004)

* Refactor AbstractSQLBuilder & AbstractJDBCImporter

* Move SQL parameter extraction code from AbstractJDBCImporter to AbstractSQLBuilder

* Refactor AbstractSQLBuilder

* Refactor AbstractJDBCImporter

* Make columns of DataRecord private

* Use insert on duplicate key update in mysql insert.
Parent 57692e64
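What this commit accomplishes, in short: the SQL builders now return a PreparedSQL object that pairs the statement text with the record-column index behind each placeholder, and on that foundation the MySQL importer switches its inserts to INSERT ... ON DUPLICATE KEY UPDATE, so a row replayed during scaling degrades to an in-place update instead of failing on a duplicate key. A minimal sketch of the resulting statement shape (mirroring MySQLImporterTest further down; `t_order` is the test table):

// Statement shape the MySQL importer builds after this change; the non-key
// column `name` is bound twice: once for VALUES, once for the UPDATE list.
String sql = "INSERT INTO `t_order`(`id`,`name`) VALUES(?,?) ON DUPLICATE KEY UPDATE `name`=?";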
@@ -25,11 +25,9 @@ import org.apache.shardingsphere.scaling.core.datasource.DataSourceManager;
import org.apache.shardingsphere.scaling.core.exception.SyncTaskExecuteException;
import org.apache.shardingsphere.scaling.core.execute.executor.AbstractShardingScalingExecutor;
import org.apache.shardingsphere.scaling.core.execute.executor.channel.Channel;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.core.execute.executor.record.FinishedRecord;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Record;
import org.apache.shardingsphere.scaling.core.execute.executor.record.RecordUtil;
import org.apache.shardingsphere.scaling.core.job.position.IncrementalPosition;
import javax.sql.DataSource;
@@ -37,9 +35,10 @@ import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.SQLIntegrityConstraintViolationException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Abstract JDBC importer implementation.
@@ -59,7 +58,7 @@ public abstract class AbstractJDBCImporter extends AbstractShardingScalingExecut
protected AbstractJDBCImporter(final ImporterConfiguration importerConfig, final DataSourceManager dataSourceManager) {
this.importerConfig = importerConfig;
this.dataSourceManager = dataSourceManager;
sqlBuilder = createSQLBuilder();
sqlBuilder = createSQLBuilder(importerConfig.getShardingColumnsMap());
}
/**
@@ -67,7 +66,7 @@ public abstract class AbstractJDBCImporter extends AbstractShardingScalingExecut
*
* @return SQL builder
*/
protected abstract AbstractSQLBuilder createSQLBuilder();
protected abstract AbstractSQLBuilder createSQLBuilder(Map<String, Set<String>> shardingColumnsMap);
@Override
public final void start() {
@@ -141,37 +140,25 @@ public abstract class AbstractJDBCImporter extends AbstractShardingScalingExecut
}
private void executeInsert(final Connection connection, final DataRecord record) throws SQLException {
String insertSql = sqlBuilder.buildInsertSQL(record);
PreparedStatement ps = connection.prepareStatement(insertSql);
ps.setQueryTimeout(30);
try {
for (int i = 0; i < record.getColumnCount(); i++) {
ps.setObject(i + 1, record.getColumn(i).getValue());
}
ps.execute();
executeSQL(connection, record, sqlBuilder.buildInsertSQL(record));
} catch (final SQLIntegrityConstraintViolationException ignored) {
}
}
private void executeUpdate(final Connection connection, final DataRecord record) throws SQLException {
List<Column> conditionColumns = RecordUtil.extractConditionColumns(record, importerConfig.getShardingColumnsMap().get(record.getTableName()));
List<Column> values = new ArrayList<>();
values.addAll(RecordUtil.extractUpdatedColumns(record));
values.addAll(conditionColumns);
String updateSql = sqlBuilder.buildUpdateSQL(record, conditionColumns);
PreparedStatement ps = connection.prepareStatement(updateSql);
for (int i = 0; i < values.size(); i++) {
ps.setObject(i + 1, values.get(i).getValue());
}
ps.execute();
executeSQL(connection, record, sqlBuilder.buildUpdateSQL(record));
}
private void executeDelete(final Connection connection, final DataRecord record) throws SQLException {
List<Column> conditionColumns = RecordUtil.extractConditionColumns(record, importerConfig.getShardingColumnsMap().get(record.getTableName()));
String deleteSql = sqlBuilder.buildDeleteSQL(record, conditionColumns);
PreparedStatement ps = connection.prepareStatement(deleteSql);
for (int i = 0; i < conditionColumns.size(); i++) {
ps.setObject(i + 1, conditionColumns.get(i).getValue());
executeSQL(connection, record, sqlBuilder.buildDeleteSQL(record));
}
private void executeSQL(final Connection connection, final DataRecord record, final PreparedSQL preparedSQL) throws SQLException {
PreparedStatement ps = connection.prepareStatement(preparedSQL.getSql());
for (int i = 0; i < preparedSQL.getValuesIndex().size(); i++) {
int columnIndex = preparedSQL.getValuesIndex().get(i);
ps.setObject(i + 1, record.getColumn(columnIndex).getValue());
}
ps.execute();
}
......
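All three write paths above now funnel into the single executeSQL method: the builder decides both the SQL text and which record columns feed its placeholders, and the importer binds parameters generically from PreparedSQL.getValuesIndex(). A standalone sketch of that binding rule (demo-only names; the index order 2, 3, 4, 0, 1 is the one asserted for UPDATE in AbstractSqlBuilderTest below):

import java.util.Arrays;
import java.util.List;

final class ValuesIndexDemo {
    public static void main(final String[] args) {
        // Record columns: id, sc, c1, c2, c3 (indexes 0..4).
        List<Object> columnValues = Arrays.asList(42, "sc-1", "a", "b", "c");
        // UPDATE `t2` SET `c1` = ?,`c2` = ?,`c3` = ? WHERE `id` = ? and `sc` = ?
        List<Integer> valuesIndex = Arrays.asList(2, 3, 4, 0, 1);
        for (int i = 0; i < valuesIndex.size(); i++) {
            // Mirrors executeSQL: JDBC parameter i + 1 gets column valuesIndex.get(i).
            System.out.printf("parameter %d <- column %d = %s%n", i + 1, valuesIndex.get(i), columnValues.get(valuesIndex.get(i)));
        }
    }
}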
@@ -17,18 +17,21 @@
package org.apache.shardingsphere.scaling.core.execute.executor.importer;
import com.google.common.collect.Collections2;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.core.execute.executor.record.RecordUtil;
import java.util.Collection;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
/**
* Abstract SQL builder.
*/
@RequiredArgsConstructor
public abstract class AbstractSQLBuilder {
private static final String INSERT_SQL_CACHE_KEY_PREFIX = "INSERT_";
@@ -37,7 +40,9 @@ public abstract class AbstractSQLBuilder {
private static final String DELETE_SQL_CACHE_KEY_PREFIX = "DELETE_";
private final ConcurrentMap<String, String> sqlCacheMap = new ConcurrentHashMap<>();
private final Map<String, Set<String>> shardingColumnsMap;
private final ConcurrentMap<String, PreparedSQL> sqlCacheMap = new ConcurrentHashMap<>();
/**
* Get left identifier quote string.
@@ -67,79 +72,90 @@
* Build insert SQL.
*
* @param dataRecord data record
* @return insert SQL
* @return insert prepared SQL
*/
public String buildInsertSQL(final DataRecord dataRecord) {
public PreparedSQL buildInsertSQL(final DataRecord dataRecord) {
String sqlCacheKey = INSERT_SQL_CACHE_KEY_PREFIX + dataRecord.getTableName();
if (!sqlCacheMap.containsKey(sqlCacheKey)) {
sqlCacheMap.put(sqlCacheKey, buildInsertSQLInternal(dataRecord.getTableName(), dataRecord.getColumns()));
sqlCacheMap.put(sqlCacheKey, buildInsertSQLInternal(dataRecord));
}
return sqlCacheMap.get(sqlCacheKey);
}
private String buildInsertSQLInternal(final String tableName, final List<Column> columns) {
protected PreparedSQL buildInsertSQLInternal(final DataRecord dataRecord) {
StringBuilder columnsLiteral = new StringBuilder();
StringBuilder holder = new StringBuilder();
for (Column each : columns) {
columnsLiteral.append(String.format("%s,", quote(each.getName())));
List<Integer> valuesIndex = new ArrayList<>();
for (int i = 0; i < dataRecord.getColumnCount(); i++) {
columnsLiteral.append(String.format("%s,", quote(dataRecord.getColumn(i).getName())));
holder.append("?,");
valuesIndex.add(i);
}
columnsLiteral.setLength(columnsLiteral.length() - 1);
holder.setLength(holder.length() - 1);
return String.format("INSERT INTO %s(%s) VALUES(%s)", quote(tableName), columnsLiteral, holder);
return new PreparedSQL(
String.format("INSERT INTO %s(%s) VALUES(%s)", quote(dataRecord.getTableName()), columnsLiteral, holder),
valuesIndex);
}
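For the base INSERT, every column is bound exactly once in declaration order, so the index list is simply 0..n-1; the method is now protected so dialect builders (see the MySQL override below) can extend both the SQL and the index list. A hypothetical stand-alone replay of the loop for columns id, sc, c1:

final class InsertBuildDemo {
    public static void main(final String[] args) {
        String[] names = {"id", "sc", "c1"};
        StringBuilder columnsLiteral = new StringBuilder();
        StringBuilder holder = new StringBuilder();
        java.util.List<Integer> valuesIndex = new java.util.ArrayList<>();
        for (int i = 0; i < names.length; i++) {
            columnsLiteral.append('`').append(names[i]).append("`,");
            holder.append("?,");
            valuesIndex.add(i);
        }
        // Trim the trailing commas, as the builder does.
        columnsLiteral.setLength(columnsLiteral.length() - 1);
        holder.setLength(holder.length() - 1);
        // Prints: INSERT INTO `t1`(`id`,`sc`,`c1`) VALUES(?,?,?) -> [0, 1, 2]
        System.out.printf("INSERT INTO `t1`(%s) VALUES(%s) -> %s%n", columnsLiteral, holder, valuesIndex);
    }
}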
/**
* Build update SQL.
*
* @param dataRecord data record
* @param conditionColumns condition columns
* @return update SQL
* @return update prepared SQL
*/
public String buildUpdateSQL(final DataRecord dataRecord, final Collection<Column> conditionColumns) {
public PreparedSQL buildUpdateSQL(final DataRecord dataRecord) {
String sqlCacheKey = UPDATE_SQL_CACHE_KEY_PREFIX + dataRecord.getTableName();
if (!sqlCacheMap.containsKey(sqlCacheKey)) {
sqlCacheMap.put(sqlCacheKey, buildUpdateSQLInternal(dataRecord.getTableName(), conditionColumns));
sqlCacheMap.put(sqlCacheKey, buildUpdateSQLInternal(dataRecord));
}
StringBuilder updatedColumnString = new StringBuilder();
for (Column each : extractUpdatedColumns(dataRecord.getColumns())) {
updatedColumnString.append(String.format("%s = ?,", quote(each.getName())));
List<Integer> valuesIndex = new ArrayList<>();
for (Integer each : RecordUtil.extractUpdatedColumns(dataRecord)) {
updatedColumnString.append(String.format("%s = ?,", quote(dataRecord.getColumn(each).getName())));
valuesIndex.add(each);
}
updatedColumnString.setLength(updatedColumnString.length() - 1);
return String.format(sqlCacheMap.get(sqlCacheKey), updatedColumnString);
}
private String buildUpdateSQLInternal(final String tableName, final Collection<Column> conditionColumns) {
return String.format("UPDATE %s SET %%s WHERE %s", quote(tableName), buildWhereSQL(conditionColumns));
PreparedSQL preparedSQL = sqlCacheMap.get(sqlCacheKey);
valuesIndex.addAll(preparedSQL.getValuesIndex());
return new PreparedSQL(
String.format(preparedSQL.getSql(), updatedColumnString),
valuesIndex);
}
private Collection<Column> extractUpdatedColumns(final Collection<Column> columns) {
return Collections2.filter(columns, Column::isUpdated);
private PreparedSQL buildUpdateSQLInternal(final DataRecord dataRecord) {
List<Integer> valuesIndex = new ArrayList<>();
return new PreparedSQL(
String.format("UPDATE %s SET %%s WHERE %s", quote(dataRecord.getTableName()), buildWhereSQL(dataRecord, valuesIndex)),
valuesIndex);
}
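The UPDATE path plays a formatting trick worth spelling out: only the table name and WHERE clause are cached per table, with an escaped %%s left in the middle, because the SET list depends on which columns each record actually updated; INSERT and DELETE, by contrast, cache their whole PreparedSQL verbatim. A minimal runnable sketch using the t2 example from the tests below:

final class UpdateTemplateDemo {
    public static void main(final String[] args) {
        // Cached once per table: the escaped %%s survives as a literal %s.
        String template = String.format("UPDATE %s SET %%s WHERE %s", "`t2`", "`id` = ? and `sc` = ?");
        // Filled in per record with only that record's updated columns.
        String sql = String.format(template, "`c1` = ?,`c2` = ?,`c3` = ?");
        System.out.println(sql); // UPDATE `t2` SET `c1` = ?,`c2` = ?,`c3` = ? WHERE `id` = ? and `sc` = ?
    }
}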
/**
* Build delete SQL.
*
* @param dataRecord data record
* @param conditionColumns condition columns
* @return delete SQL
* @return delete prepared SQL
*/
public String buildDeleteSQL(final DataRecord dataRecord, final Collection<Column> conditionColumns) {
public PreparedSQL buildDeleteSQL(final DataRecord dataRecord) {
String sqlCacheKey = DELETE_SQL_CACHE_KEY_PREFIX + dataRecord.getTableName();
if (!sqlCacheMap.containsKey(sqlCacheKey)) {
sqlCacheMap.put(sqlCacheKey, buildDeleteSQLInternal(dataRecord.getTableName(), conditionColumns));
sqlCacheMap.put(sqlCacheKey, buildDeleteSQLInternal(dataRecord));
}
return sqlCacheMap.get(sqlCacheKey);
}
private String buildDeleteSQLInternal(final String tableName, final Collection<Column> conditionColumns) {
return String.format("DELETE FROM %s WHERE %s", quote(tableName), buildWhereSQL(conditionColumns));
private PreparedSQL buildDeleteSQLInternal(final DataRecord dataRecord) {
List<Integer> columnsIndex = new ArrayList<>();
return new PreparedSQL(
String.format("DELETE FROM %s WHERE %s", quote(dataRecord.getTableName()), buildWhereSQL(dataRecord, columnsIndex)),
columnsIndex);
}
private String buildWhereSQL(final Collection<Column> conditionColumns) {
private String buildWhereSQL(final DataRecord dataRecord, final List<Integer> valuesIndex) {
StringBuilder where = new StringBuilder();
for (Column each : conditionColumns) {
where.append(String.format("%s = ? and ", quote(each.getName())));
for (Integer each : RecordUtil.extractConditionColumns(dataRecord, shardingColumnsMap.get(dataRecord.getTableName()))) {
where.append(String.format("%s = ? and ", quote(dataRecord.getColumn(each).getName())));
valuesIndex.add(each);
}
where.setLength(where.length() - 5);
return where.toString();
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.scaling.core.execute.executor.importer;
import lombok.Getter;
import java.util.Collections;
import java.util.List;
/**
* Prepared SQL, holding the complete SQL text and the full list of value indexes.
*/
@Getter
public class PreparedSQL {
private final String sql;
private final List<Integer> valuesIndex;
public PreparedSQL(final String sql, final List<Integer> valuesIndex) {
this.sql = sql;
this.valuesIndex = Collections.unmodifiableList(valuesIndex);
}
}
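Because a single PreparedSQL instance is cached and shared across every record of a table, its index list is wrapped unmodifiable; a caller that needs to extend it, as buildUpdateSQL does above, copies into a fresh list first. A quick hypothetical illustration:

PreparedSQL prepared = new PreparedSQL("DELETE FROM `t3` WHERE `id` = ?", java.util.Collections.singletonList(0));
java.util.List<Integer> extended = new java.util.ArrayList<>(prepared.getValuesIndex()); // safe: copy before extending
prepared.getValuesIndex().add(1); // throws UnsupportedOperationException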
@@ -31,8 +31,6 @@ import java.util.List;
/**
* Data record.
*/
@Setter
@Getter
@EqualsAndHashCode(of = {"tableName", "primaryKeyValue"}, callSuper = false)
@ToString
public final class DataRecord extends Record {
@@ -41,8 +39,12 @@ public final class DataRecord extends Record {
private final List<Object> primaryKeyValue = new LinkedList<>();
@Setter
@Getter
private String type;
@Setter
@Getter
private String tableName;
public DataRecord(final Position position, final int columnCount) {
......
@@ -20,9 +20,10 @@ package org.apache.shardingsphere.scaling.core.execute.executor.record;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* Record utility.
@@ -31,51 +32,45 @@ import java.util.Set;
public final class RecordUtil {
/**
* Extract primary columns from data record.
* Extract indexes of primary columns from data record.
*
* @param dataRecord data record
* @return primary columns
* @return indexes of primary columns
*/
public static List<Column> extractPrimaryColumns(final DataRecord dataRecord) {
List<Column> result = new ArrayList<>(dataRecord.getColumns().size());
for (Column each : dataRecord.getColumns()) {
if (each.isPrimaryKey()) {
result.add(each);
}
}
return result;
public static List<Integer> extractPrimaryColumns(final DataRecord dataRecord) {
return IntStream.range(0, dataRecord.getColumnCount())
.filter(each -> dataRecord.getColumn(each).isPrimaryKey())
.mapToObj(each -> each)
.collect(Collectors.toList());
}
/**
* Extract condition columns(include primary and sharding columns) from data record.
* Extract indexes of condition columns (primary key and sharding columns) from data record.
*
* @param dataRecord data record
* @param shardingColumns sharding columns
* @return condition columns
* @return indexes of condition columns
*/
public static List<Column> extractConditionColumns(final DataRecord dataRecord, final Set<String> shardingColumns) {
List<Column> result = new ArrayList<>(dataRecord.getColumns().size());
for (Column each : dataRecord.getColumns()) {
if (each.isPrimaryKey() || shardingColumns.contains(each.getName())) {
result.add(each);
}
}
return result;
public static List<Integer> extractConditionColumns(final DataRecord dataRecord, final Set<String> shardingColumns) {
return IntStream.range(0, dataRecord.getColumnCount())
.filter(each -> {
Column column = dataRecord.getColumn(each);
return column.isPrimaryKey() || shardingColumns.contains(column.getName());
})
.mapToObj(each -> each)
.collect(Collectors.toList());
}
/**
* Extract updated columns from data record.
*
* @param dataRecord data record
* @return updated columns
* @return indexes of updated columns
*/
public static List<Column> extractUpdatedColumns(final DataRecord dataRecord) {
List<Column> result = new ArrayList<>(dataRecord.getColumns().size());
for (Column each : dataRecord.getColumns()) {
if (each.isUpdated()) {
result.add(each);
}
}
return result;
public static List<Integer> extractUpdatedColumns(final DataRecord dataRecord) {
return IntStream.range(0, dataRecord.getColumnCount())
.filter(each -> dataRecord.getColumn(each).isUpdated())
.mapToObj(each -> each)
.collect(Collectors.toList());
}
}
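RecordUtil now hands back column indexes rather than Column references, which is what lets PreparedSQL describe parameter bindings positionally. The mapToObj(each -> each) step merely boxes the primitive ints; IntStream.boxed() is the equivalent. A self-contained sketch of the pattern with demo data:

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

final class IndexExtractionDemo {
    public static void main(final String[] args) {
        boolean[] primaryKey = {true, false, false, true};
        List<Integer> indexes = IntStream.range(0, primaryKey.length)
                .filter(i -> primaryKey[i])
                .boxed() // same effect as mapToObj(each -> each)
                .collect(Collectors.toList());
        System.out.println(indexes); // [0, 3]
    }
}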
@@ -17,17 +17,17 @@
package org.apache.shardingsphere.scaling.core.execute.executor.importer;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.shardingsphere.scaling.core.config.ScalingDataSourceConfiguration;
import org.apache.shardingsphere.scaling.core.config.ImporterConfiguration;
import org.apache.shardingsphere.scaling.core.config.ScalingDataSourceConfiguration;
import org.apache.shardingsphere.scaling.core.datasource.DataSourceManager;
import org.apache.shardingsphere.scaling.core.execute.executor.channel.Channel;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.core.execute.executor.record.FinishedRecord;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Record;
import org.apache.shardingsphere.scaling.core.execute.executor.record.RecordUtil;
import org.apache.shardingsphere.scaling.core.job.position.NopPosition;
import org.junit.Before;
import org.junit.Test;
......@@ -39,7 +39,6 @@ import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -87,7 +86,7 @@ public final class AbstractJDBCImporterTest {
jdbcImporter = new AbstractJDBCImporter(getImporterConfiguration(), dataSourceManager) {
@Override
protected AbstractSQLBuilder createSQLBuilder() {
protected AbstractSQLBuilder createSQLBuilder(final Map<String, Set<String>> shardingColumnsMap) {
return sqlBuilder;
}
};
@@ -97,9 +96,9 @@
}
@Test
public void assertWriteInsertDataRecord() throws SQLException {
public void assertInsertDataRecord() throws SQLException {
DataRecord insertRecord = getDataRecord("INSERT");
when(sqlBuilder.buildInsertSQL(insertRecord)).thenReturn(INSERT_SQL);
when(sqlBuilder.buildInsertSQL(insertRecord)).thenReturn(new PreparedSQL(INSERT_SQL, Lists.newArrayList(0, 1, 2)));
when(connection.prepareStatement(INSERT_SQL)).thenReturn(preparedStatement);
when(channel.fetchRecords(100, 3)).thenReturn(mockRecords(insertRecord));
jdbcImporter.run();
@@ -112,7 +111,7 @@
@Test
public void assertDeleteDataRecord() throws SQLException {
DataRecord deleteRecord = getDataRecord("DELETE");
when(sqlBuilder.buildDeleteSQL(deleteRecord, mockConditionColumns(deleteRecord))).thenReturn(DELETE_SQL);
when(sqlBuilder.buildDeleteSQL(deleteRecord)).thenReturn(new PreparedSQL(DELETE_SQL, Lists.newArrayList(0, 1)));
when(connection.prepareStatement(DELETE_SQL)).thenReturn(preparedStatement);
when(channel.fetchRecords(100, 3)).thenReturn(mockRecords(deleteRecord));
jdbcImporter.run();
@@ -124,7 +123,7 @@
@Test
public void assertUpdateDataRecord() throws SQLException {
DataRecord updateRecord = getDataRecord("UPDATE");
when(sqlBuilder.buildUpdateSQL(updateRecord, mockConditionColumns(updateRecord))).thenReturn(UPDATE_SQL);
when(sqlBuilder.buildUpdateSQL(updateRecord)).thenReturn(new PreparedSQL(UPDATE_SQL, Lists.newArrayList(1, 2, 0, 1)));
when(connection.prepareStatement(UPDATE_SQL)).thenReturn(preparedStatement);
when(channel.fetchRecords(100, 3)).thenReturn(mockRecords(updateRecord));
jdbcImporter.run();
@@ -135,10 +134,6 @@
verify(preparedStatement).execute();
}
private Collection<Column> mockConditionColumns(final DataRecord dataRecord) {
return RecordUtil.extractConditionColumns(dataRecord, Sets.newHashSet("user"));
}
private List<Record> mockRecords(final DataRecord dataRecord) {
List<Record> result = new LinkedList<>();
result.add(dataRecord);
......
@@ -20,29 +20,37 @@ package org.apache.shardingsphere.scaling.core.execute.executor.importer;
import com.google.common.collect.Sets;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.core.execute.executor.record.RecordUtil;
import org.apache.shardingsphere.scaling.core.job.position.NopPosition;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import java.util.Collection;
import java.util.Map;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AbstractSqlBuilderTest {
@Mock
private Map shardingColumnsMap;
private AbstractSQLBuilder sqlBuilder;
@Before
public void setUp() {
sqlBuilder = new AbstractSQLBuilder() {
sqlBuilder = new AbstractSQLBuilder(shardingColumnsMap) {
@Override
protected String getLeftIdentifierQuoteString() {
return "`";
}
@Override
protected String getRightIdentifierQuoteString() {
return "`";
@@ -52,38 +60,66 @@ public class AbstractSqlBuilderTest {
@Test
public void assertBuildInsertSQL() {
String actual = sqlBuilder.buildInsertSQL(mockDataRecord("t1"));
assertThat(actual, is("INSERT INTO `t1`(`id`,`sc`,`c1`,`c2`,`c3`) VALUES(?,?,?,?,?)"));
PreparedSQL actual = sqlBuilder.buildInsertSQL(mockDataRecord("t1"));
assertThat(actual.getSql(), is("INSERT INTO `t1`(`id`,`sc`,`c1`,`c2`,`c3`) VALUES(?,?,?,?,?)"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(0, 1, 2, 3, 4));
}
@Test
public void assertBuildUpdateSQLWithPrimaryKey() {
String actual = sqlBuilder.buildUpdateSQL(mockDataRecord("t2"), RecordUtil.extractPrimaryColumns(mockDataRecord("t2")));
assertThat(actual, is("UPDATE `t2` SET `c1` = ?,`c2` = ?,`c3` = ? WHERE `id` = ?"));
when(shardingColumnsMap.get("t2")).thenReturn(Sets.newHashSet());
PreparedSQL actual = sqlBuilder.buildUpdateSQL(mockDataRecord("t2"));
assertThat(actual.getSql(), is("UPDATE `t2` SET `c1` = ?,`c2` = ?,`c3` = ? WHERE `id` = ?"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(2, 3, 4, 0));
}
@Test
public void assertBuildUpdateSQLWithShardingColumns() {
when(shardingColumnsMap.get("t2")).thenReturn(Sets.newHashSet("sc"));
DataRecord dataRecord = mockDataRecord("t2");
String actual = sqlBuilder.buildUpdateSQL(dataRecord, mockConditionColumns(dataRecord));
assertThat(actual, is("UPDATE `t2` SET `c1` = ?,`c2` = ?,`c3` = ? WHERE `id` = ? and `sc` = ?"));
PreparedSQL actual = sqlBuilder.buildUpdateSQL(dataRecord);
assertThat(actual.getSql(), is("UPDATE `t2` SET `c1` = ?,`c2` = ?,`c3` = ? WHERE `id` = ? and `sc` = ?"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(2, 3, 4, 0, 1));
}
@Test
public void assertBuildDeleteSQLWithPrimaryKey() {
String actual = sqlBuilder.buildDeleteSQL(mockDataRecord("t3"), RecordUtil.extractPrimaryColumns(mockDataRecord("t3")));
assertThat(actual, is("DELETE FROM `t3` WHERE `id` = ?"));
public void assertBuildUpdateSQLWithShardingColumnsUseCache() {
when(shardingColumnsMap.get("t2")).thenReturn(Sets.newHashSet("sc"));
DataRecord dataRecord = mockDataRecord("t2");
PreparedSQL actual = sqlBuilder.buildUpdateSQL(dataRecord);
assertThat(actual.getSql(), is("UPDATE `t2` SET `c1` = ?,`c2` = ?,`c3` = ? WHERE `id` = ? and `sc` = ?"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(2, 3, 4, 0, 1));
actual = sqlBuilder.buildUpdateSQL(mockDataRecord2("t2"));
assertThat(actual.getSql(), is("UPDATE `t2` SET `c1` = ?,`c3` = ? WHERE `id` = ? and `sc` = ?"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(2, 4, 0, 1));
}
private DataRecord mockDataRecord2(final String tableName) {
DataRecord result = new DataRecord(new NopPosition(), 4);
result.setTableName(tableName);
result.addColumn(new Column("id", "", false, true));
result.addColumn(new Column("sc", "", false, false));
result.addColumn(new Column("c1", "", true, false));
result.addColumn(new Column("c2", "", false, false));
result.addColumn(new Column("c3", "", true, false));
return result;
}
@Test
public void assertBuildDeleteSQLWithConditionColumns() {
DataRecord dataRecord = mockDataRecord("t3");
String actual = sqlBuilder.buildDeleteSQL(dataRecord, mockConditionColumns(dataRecord));
assertThat(actual, is("DELETE FROM `t3` WHERE `id` = ? and `sc` = ?"));
public void assertBuildDeleteSQLWithPrimaryKey() {
when(shardingColumnsMap.get("t3")).thenReturn(Sets.newHashSet());
PreparedSQL actual = sqlBuilder.buildDeleteSQL(mockDataRecord("t3"));
assertThat(actual.getSql(), is("DELETE FROM `t3` WHERE `id` = ?"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(0));
}
private Collection<Column> mockConditionColumns(final DataRecord dataRecord) {
return RecordUtil.extractConditionColumns(dataRecord, Sets.newHashSet("sc"));
@Test
public void assertBuildDeleteSQLWithShardingColumns() {
when(shardingColumnsMap.get("t3")).thenReturn(Sets.newHashSet("sc"));
DataRecord dataRecord = mockDataRecord("t3");
PreparedSQL actual = sqlBuilder.buildDeleteSQL(dataRecord);
assertThat(actual.getSql(), is("DELETE FROM `t3` WHERE `id` = ? and `sc` = ?"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(0, 1));
}
private DataRecord mockDataRecord(final String tableName) {
......
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.scaling.core.fixture;
import com.google.common.collect.Maps;
import org.apache.shardingsphere.scaling.core.check.AbstractDataConsistencyChecker;
import org.apache.shardingsphere.scaling.core.check.DataConsistencyCheckResult;
import org.apache.shardingsphere.scaling.core.check.DataConsistencyChecker;
@@ -44,7 +45,7 @@ public final class FixtureDataConsistencyChecker extends AbstractDataConsistency
@Override
protected AbstractSQLBuilder getSqlBuilder() {
return new AbstractSQLBuilder() {
return new AbstractSQLBuilder(Maps.newHashMap()) {
@Override
protected String getLeftIdentifierQuoteString() {
return "`";
......
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.scaling.mysql;
import com.google.common.collect.Maps;
import org.apache.shardingsphere.scaling.core.check.AbstractDataConsistencyChecker;
import org.apache.shardingsphere.scaling.core.check.DataConsistencyChecker;
import org.apache.shardingsphere.scaling.core.datasource.DataSourceWrapper;
@@ -98,6 +99,6 @@ public final class MySQLDataConsistencyChecker extends AbstractDataConsistencyCh
@Override
protected MySQLSQLBuilder getSqlBuilder() {
return new MySQLSQLBuilder();
return new MySQLSQLBuilder(Maps.newHashMap());
}
}
@@ -22,6 +22,9 @@ import org.apache.shardingsphere.scaling.core.datasource.DataSourceManager;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractJDBCImporter;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractSQLBuilder;
import java.util.Map;
import java.util.Set;
/**
* MySQL importer.
*/
@@ -32,7 +35,7 @@ public final class MySQLImporter extends AbstractJDBCImporter {
}
@Override
protected AbstractSQLBuilder createSQLBuilder() {
return new MySQLSQLBuilder();
protected AbstractSQLBuilder createSQLBuilder(final Map<String, Set<String>> shardingColumnsMap) {
return new MySQLSQLBuilder(shardingColumnsMap);
}
}
@@ -18,12 +18,24 @@
package org.apache.shardingsphere.scaling.mysql;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractSQLBuilder;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.PreparedSQL;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* MySQL SQL builder.
*/
public final class MySQLSQLBuilder extends AbstractSQLBuilder {
public MySQLSQLBuilder(final Map<String, Set<String>> shardingColumnsMap) {
super(shardingColumnsMap);
}
@Override
public String getLeftIdentifierQuoteString() {
return "`";
@@ -34,6 +46,22 @@ public final class MySQLSQLBuilder extends AbstractSQLBuilder {
return "`";
}
@Override
protected PreparedSQL buildInsertSQLInternal(final DataRecord dataRecord) {
PreparedSQL preparedSQL = super.buildInsertSQLInternal(dataRecord);
StringBuilder insertSQL = new StringBuilder(preparedSQL.getSql() + " ON DUPLICATE KEY UPDATE ");
List<Integer> valuesIndex = new ArrayList<>(preparedSQL.getValuesIndex());
for (int i = 0; i < dataRecord.getColumnCount(); i++) {
Column column = dataRecord.getColumn(i);
if (!dataRecord.getColumn(i).isPrimaryKey()) {
insertSQL.append(quote(column.getName())).append("=?,");
valuesIndex.add(i);
}
}
insertSQL.setLength(insertSQL.length() - 1);
return new PreparedSQL(insertSQL.toString(), valuesIndex);
}
/**
* Build select sum crc32 SQL.
*
......
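The MySQL override appends an assignment for every non-primary-key column and pushes those column indexes onto the values index a second time, so each record binds its non-key values twice: once for VALUES and once for the UPDATE list. The outcome asserted in MySQLSqlBuilderTest below, where id is the primary key:

String sql = "INSERT INTO `t_order`(`id`,`name`,`age`) VALUES(?,?,?) ON DUPLICATE KEY UPDATE `name`=?,`age`=?";
Integer[] valuesIndex = {0, 1, 2, 1, 2}; // JDBC parameter i + 1 <- record column valuesIndex[i]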
@@ -17,11 +17,14 @@
package org.apache.shardingsphere.scaling.mysql;
import com.google.common.collect.Maps;
import org.apache.shardingsphere.scaling.core.config.ImporterConfiguration;
import org.apache.shardingsphere.scaling.core.datasource.DataSourceManager;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.PreparedSQL;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.mysql.binlog.BinlogPosition;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -42,8 +45,9 @@ public final class MySQLImporterTest {
@Test
public void assertCreateSqlBuilder() {
MySQLImporter mySQLImporter = new MySQLImporter(importerConfig, dataSourceManager);
String insertSQL = mySQLImporter.createSQLBuilder().buildInsertSQL(mockDataRecord());
assertThat(insertSQL, is("INSERT INTO `t_order`(`id`,`name`) VALUES(?,?)"));
PreparedSQL insertSQL = mySQLImporter.createSQLBuilder(Maps.newHashMap()).buildInsertSQL(mockDataRecord());
assertThat(insertSQL.getSql(), is("INSERT INTO `t_order`(`id`,`name`) VALUES(?,?) ON DUPLICATE KEY UPDATE `name`=?"));
assertThat(insertSQL.getValuesIndex().toArray(), Matchers.arrayContaining(0, 1, 1));
}
private DataRecord mockDataRecord() {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.scaling.mysql;
import com.google.common.collect.Maps;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.PreparedSQL;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.mysql.binlog.BinlogPosition;
import org.hamcrest.Matchers;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
public final class MySQLSqlBuilderTest {
@Test
public void assertBuildInsertSQL() {
PreparedSQL actual = new MySQLSQLBuilder(Maps.newHashMap()).buildInsertSQL(mockDataRecord());
assertThat(actual.getSql(), is("INSERT INTO `t_order`(`id`,`name`,`age`) VALUES(?,?,?) ON DUPLICATE KEY UPDATE `name`=?,`age`=?"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(0, 1, 2, 1, 2));
}
private DataRecord mockDataRecord() {
DataRecord result = new DataRecord(new BinlogPosition("", 1), 2);
result.setTableName("t_order");
result.addColumn(new Column("id", 1, true, true));
result.addColumn(new Column("name", "", true, false));
result.addColumn(new Column("age", 1, true, false));
return result;
}
}
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.scaling.postgresql;
import com.google.common.collect.Maps;
import org.apache.shardingsphere.scaling.core.check.AbstractDataConsistencyChecker;
import org.apache.shardingsphere.scaling.core.check.DataConsistencyChecker;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractSQLBuilder;
@@ -41,6 +42,6 @@ public final class PostgreSQLDataConsistencyChecker extends AbstractDataConsiste
@Override
protected AbstractSQLBuilder getSqlBuilder() {
return new PostgreSQLSQLBuilder();
return new PostgreSQLSQLBuilder(Maps.newHashMap());
}
}
@@ -22,6 +22,9 @@ import org.apache.shardingsphere.scaling.core.datasource.DataSourceManager;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractJDBCImporter;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractSQLBuilder;
import java.util.Map;
import java.util.Set;
/**
* PostgreSQL importer.
*/
@@ -32,8 +35,8 @@ public final class PostgreSQLImporter extends AbstractJDBCImporter {
}
@Override
protected AbstractSQLBuilder createSQLBuilder() {
return new PostgreSQLSQLBuilder();
protected AbstractSQLBuilder createSQLBuilder(final Map<String, Set<String>> shardingColumnsMap) {
return new PostgreSQLSQLBuilder(shardingColumnsMap);
}
}
@@ -17,16 +17,23 @@
package org.apache.shardingsphere.scaling.postgresql;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractSQLBuilder;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.PreparedSQL;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.core.execute.executor.record.RecordUtil;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.AbstractSQLBuilder;
import java.util.Map;
import java.util.Set;
/**
* PostgreSQL SQL builder.
*/
public final class PostgreSQLSQLBuilder extends AbstractSQLBuilder {
public PostgreSQLSQLBuilder(final Map<String, Set<String>> shardingColumnsMap) {
super(shardingColumnsMap);
}
@Override
public String getLeftIdentifierQuoteString() {
return "\"";
@@ -38,14 +45,15 @@ public final class PostgreSQLSQLBuilder extends AbstractSQLBuilder {
}
@Override
public String buildInsertSQL(final DataRecord dataRecord) {
return super.buildInsertSQL(dataRecord) + buildConflictSQL(dataRecord);
public PreparedSQL buildInsertSQL(final DataRecord dataRecord) {
PreparedSQL preparedSQL = super.buildInsertSQL(dataRecord);
return new PreparedSQL(preparedSQL.getSql() + buildConflictSQL(dataRecord), preparedSQL.getValuesIndex());
}
private String buildConflictSQL(final DataRecord dataRecord) {
StringBuilder result = new StringBuilder(" ON CONFLICT (");
for (Column each : RecordUtil.extractPrimaryColumns(dataRecord)) {
result.append(each.getName()).append(",");
for (Integer each : RecordUtil.extractPrimaryColumns(dataRecord)) {
result.append(dataRecord.getColumn(each).getName()).append(",");
}
result.setLength(result.length() - 1);
result.append(") DO NOTHING");
......
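PostgreSQL takes the lighter route: ON CONFLICT (...) DO NOTHING appends no placeholders, so the values index passes through the base insert unchanged and a conflicting row is skipped rather than updated. The shape asserted in PostgreSQLSqlBuilderTest below, with id as the primary key:

String sql = "INSERT INTO \"t_order\"(\"id\",\"name\") VALUES(?,?) ON CONFLICT (id) DO NOTHING";
Integer[] valuesIndex = {0, 1}; // unchanged from the base insert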
@@ -17,11 +17,14 @@
package org.apache.shardingsphere.scaling.postgresql;
import com.google.common.collect.Maps;
import org.apache.shardingsphere.scaling.core.config.ImporterConfiguration;
import org.apache.shardingsphere.scaling.core.datasource.DataSourceManager;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.PreparedSQL;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.postgresql.wal.WalPosition;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -43,8 +46,9 @@ public final class PostgreSQLImporterTest {
@Test
public void assertCreateSQLBuilder() {
PostgreSQLImporter postgreSQLImporter = new PostgreSQLImporter(importerConfig, dataSourceManager);
String insertSQL = postgreSQLImporter.createSQLBuilder().buildInsertSQL(mockDataRecord());
assertThat(insertSQL, is("INSERT INTO \"t_order\"(\"id\",\"name\") VALUES(?,?) ON CONFLICT (id) DO NOTHING"));
PreparedSQL insertSQL = postgreSQLImporter.createSQLBuilder(Maps.newHashMap()).buildInsertSQL(mockDataRecord());
assertThat(insertSQL.getSql(), is("INSERT INTO \"t_order\"(\"id\",\"name\") VALUES(?,?) ON CONFLICT (id) DO NOTHING"));
assertThat(insertSQL.getValuesIndex().toArray(), Matchers.arrayContaining(0, 1));
}
private DataRecord mockDataRecord() {
......
@@ -17,9 +17,12 @@
package org.apache.shardingsphere.scaling.postgresql;
import com.google.common.collect.Maps;
import org.apache.shardingsphere.scaling.core.execute.executor.importer.PreparedSQL;
import org.apache.shardingsphere.scaling.core.execute.executor.record.Column;
import org.apache.shardingsphere.scaling.core.execute.executor.record.DataRecord;
import org.apache.shardingsphere.scaling.postgresql.wal.WalPosition;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.postgresql.replication.LogSequenceNumber;
@@ -30,8 +33,9 @@ public final class PostgreSQLSqlBuilderTest {
@Test
public void assertBuildInsertSQL() {
String actual = new PostgreSQLSQLBuilder().buildInsertSQL(mockDataRecord());
assertThat(actual, is("INSERT INTO \"t_order\"(\"id\",\"name\") VALUES(?,?) ON CONFLICT (id) DO NOTHING"));
PreparedSQL actual = new PostgreSQLSQLBuilder(Maps.newHashMap()).buildInsertSQL(mockDataRecord());
assertThat(actual.getSql(), is("INSERT INTO \"t_order\"(\"id\",\"name\") VALUES(?,?) ON CONFLICT (id) DO NOTHING"));
assertThat(actual.getValuesIndex().toArray(), Matchers.arrayContaining(0, 1));
}
private DataRecord mockDataRecord() {
......