未验证 提交 57047eca 编写于 作者: D DemonLms 提交者: GitHub

fix(encrypt): Fixed issues#5979 (#5985)

* fix(encrypt): Fixed issues#5979

Fixed data encryption can not properly processing SQL mixed with `?` and
literal. And add some test cases for mixed `?` and literal.

Closes #5979

* pref: Get the parameter index using an existing implementation.

* pref: Optimize instert SQL rewrite logic for mixing letters and parameters: literals insert literals, parameters insert parameters.

* pref: Extract duplicate code as a function.
Co-authored-by: NYour Name <you@example.com>
Co-authored-by: NDemon <demon@liuguopingdeiMac.local>
上级 93181172
......@@ -26,6 +26,8 @@ import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedPar
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
import org.apache.shardingsphere.sql.parser.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import java.util.ArrayList;
import java.util.Collection;
......@@ -38,12 +40,12 @@ import java.util.Optional;
* Insert value parameter rewriter for encrypt.
*/
public final class EncryptInsertValueParameterRewriter extends EncryptParameterRewriter<InsertStatementContext> {
@Override
protected boolean isNeedRewriteForEncrypt(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getSetAssignment().isPresent();
}
@Override
public void rewrite(final ParameterBuilder parameterBuilder, final InsertStatementContext insertStatementContext, final List<Object> parameters) {
String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
......@@ -54,21 +56,26 @@ public final class EncryptInsertValueParameterRewriter extends EncryptParameterR
encryptAlgorithm -> encryptInsertValues((GroupedParameterBuilder) parameterBuilder, insertStatementContext, encryptAlgorithm, tableName, columnName));
}
}
private void encryptInsertValues(final GroupedParameterBuilder parameterBuilder,
final InsertStatementContext insertStatementContext, final EncryptAlgorithm encryptAlgorithm, final String tableName, final String encryptLogicColumnName) {
int columnIndex = getColumnIndex(parameterBuilder, insertStatementContext, encryptLogicColumnName);
int count = 0;
for (List<Object> each : insertStatementContext.getGroupedParameters()) {
int paramterIndex = insertStatementContext.getInsertValueContexts().get(count).getParameterIndex(columnIndex);
if (!each.isEmpty()) {
StandardParameterBuilder standardParameterBuilder = parameterBuilder.getParameterBuilders().get(count);
encryptInsertValue(
encryptAlgorithm, tableName, columnIndex, insertStatementContext.getInsertValueContexts().get(count).getValue(columnIndex), standardParameterBuilder, encryptLogicColumnName);
ExpressionSegment expressionSegment = insertStatementContext.getInsertValueContexts().get(count).getValueExpressions().get(columnIndex);
if (expressionSegment instanceof ParameterMarkerExpressionSegment) {
encryptInsertValue(
encryptAlgorithm, tableName, paramterIndex, insertStatementContext.getInsertValueContexts().get(count).getValue(columnIndex),
standardParameterBuilder, encryptLogicColumnName);
}
}
count++;
}
}
private int getColumnIndex(final GroupedParameterBuilder parameterBuilder, final InsertStatementContext insertStatementContext, final String encryptLogicColumnName) {
List<String> columnNames;
if (parameterBuilder.getDerivedColumnName().isPresent()) {
......@@ -79,13 +86,10 @@ public final class EncryptInsertValueParameterRewriter extends EncryptParameterR
}
return columnNames.indexOf(encryptLogicColumnName);
}
private void encryptInsertValue(final EncryptAlgorithm encryptAlgorithm, final String tableName, final int columnIndex,
private void encryptInsertValue(final EncryptAlgorithm encryptAlgorithm, final String tableName, final int paramterIndex,
final Object originalValue, final StandardParameterBuilder parameterBuilder, final String encryptLogicColumnName) {
// FIXME: can process all part of insert value is ? or literal, can not process mix ? and literal
// For example: values (?, ?), (1, 1) can process
// For example: values (?, 1), (?, 2) can not process
parameterBuilder.addReplacedParameters(columnIndex, encryptAlgorithm.encrypt(originalValue));
parameterBuilder.addReplacedParameters(paramterIndex, encryptAlgorithm.encrypt(originalValue));
Collection<Object> addedParameters = new LinkedList<>();
if (encryptAlgorithm instanceof QueryAssistedEncryptAlgorithm) {
Optional<String> assistedColumnName = getEncryptRule().findAssistedQueryColumn(tableName, encryptLogicColumnName);
......@@ -96,10 +100,10 @@ public final class EncryptInsertValueParameterRewriter extends EncryptParameterR
addedParameters.add(originalValue);
}
if (!addedParameters.isEmpty()) {
if (!parameterBuilder.getAddedIndexAndParameters().containsKey(columnIndex + 1)) {
parameterBuilder.getAddedIndexAndParameters().put(columnIndex + 1, new LinkedList<>());
if (!parameterBuilder.getAddedIndexAndParameters().containsKey(paramterIndex + 1)) {
parameterBuilder.getAddedIndexAndParameters().put(paramterIndex + 1, new LinkedList<>());
}
parameterBuilder.getAddedIndexAndParameters().get(columnIndex + 1).addAll(addedParameters);
parameterBuilder.getAddedIndexAndParameters().get(paramterIndex + 1).addAll(addedParameters);
}
}
}
......@@ -133,7 +133,7 @@ public final class EncryptInsertValuesTokenGenerator extends BaseEncryptSQLToken
private void addPlainColumn(final InsertValue insertValueToken, final int columnIndex,
final String tableName, final String columnName, final InsertValueContext insertValueContext, final Object originalValue) {
if (getEncryptRule().findPlainColumn(tableName, columnName).isPresent()) {
DerivedSimpleExpressionSegment derivedExpressionSegment = insertValueContext.getParameters().isEmpty()
DerivedSimpleExpressionSegment derivedExpressionSegment = isAddLiteralExpressionSegment(insertValueContext, columnIndex)
? new DerivedLiteralExpressionSegment(originalValue) : new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
insertValueToken.getValues().add(columnIndex + 1, derivedExpressionSegment);
}
......@@ -142,12 +142,17 @@ public final class EncryptInsertValuesTokenGenerator extends BaseEncryptSQLToken
private void addAssistedQueryColumn(final InsertValue insertValueToken, final EncryptAlgorithm encryptAlgorithm, final int columnIndex,
final String tableName, final String columnName, final InsertValueContext insertValueContext, final Object originalValue) {
if (getEncryptRule().findAssistedQueryColumn(tableName, columnName).isPresent()) {
DerivedSimpleExpressionSegment derivedExpressionSegment = insertValueContext.getParameters().isEmpty()
DerivedSimpleExpressionSegment derivedExpressionSegment = isAddLiteralExpressionSegment(insertValueContext, columnIndex)
? new DerivedLiteralExpressionSegment(((QueryAssistedEncryptAlgorithm) encryptAlgorithm).queryAssistedEncrypt(null == originalValue ? null : originalValue.toString()))
: new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
insertValueToken.getValues().add(columnIndex + 1, derivedExpressionSegment);
}
}
private boolean isAddLiteralExpressionSegment(final InsertValueContext insertValueContext, final int columnIndex) {
return insertValueContext.getParameters().isEmpty()
|| insertValueContext.getValueExpressions().get(columnIndex) instanceof LiteralExpressionSegment;
}
private int getParameterIndexCount(final InsertValue insertValueToken) {
int result = 0;
......
......@@ -82,6 +82,17 @@ public final class InsertValueContext {
ExpressionSegment valueExpression = valueExpressions.get(index);
return valueExpression instanceof ParameterMarkerExpressionSegment ? parameters.get(getParameterIndex(valueExpression)) : ((LiteralExpressionSegment) valueExpression).getLiterals();
}
/**
* Get parameter index via column index.
*
* @param index column index
* @return parameter index
*/
public int getParameterIndex(final int index) {
ExpressionSegment valueExpression = valueExpressions.get(index);
return getParameterIndex(valueExpression);
}
private int getParameterIndex(final ExpressionSegment valueExpression) {
int result = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册