未验证 提交 fe56f1d6 编写于 作者: L Liang Zhang 提交者: GitHub

Decouple AuthenticationEngine and BackendConnection (#7038)

* Decouple AuthenticationEngine and BackendConnection

* Decouple AuthenticationEngine and BackendConnection
上级 5ee3490c
......@@ -69,7 +69,7 @@ public final class BackendConnection implements JDBCExecutionConnection, AutoClo
private int connectionId;
@Setter
private String userName;
private String username;
private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
......
......@@ -53,7 +53,7 @@ public final class ShowDatabasesBackendHandler implements TextProtocolBackendHan
private Collection getSchemaNames() {
Collection<String> result = new LinkedList<>(ProxySchemaContexts.getInstance().getSchemaNames());
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUserName()).getAuthorizedSchemas();
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUsername()).getAuthorizedSchemas();
if (!authorizedSchemas.isEmpty()) {
result.retainAll(authorizedSchemas);
}
......
......@@ -53,7 +53,7 @@ public final class UseDatabaseBackendHandler implements TextProtocolBackendHandl
}
private boolean isAuthorizedSchema(final String schema) {
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUserName()).getAuthorizedSchemas();
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUsername()).getAuthorizedSchemas();
return authorizedSchemas.isEmpty() || authorizedSchemas.contains(schema);
}
......
......@@ -56,7 +56,7 @@ public final class ShowDatabasesBackendHandlerTest {
@SneakyThrows(ReflectiveOperationException.class)
public void setUp() {
BackendConnection backendConnection = mock(BackendConnection.class);
when(backendConnection.getUserName()).thenReturn("root");
when(backendConnection.getUsername()).thenReturn("root");
showDatabasesBackendHandler = new ShowDatabasesBackendHandler(backendConnection);
Field schemaContexts = ProxySchemaContexts.getInstance().getClass().getDeclaredField("schemaContexts");
schemaContexts.setAccessible(true);
......
......@@ -54,7 +54,7 @@ public class ShowTablesBackendHandlerTest {
@SneakyThrows(ReflectiveOperationException.class)
public void setUp() {
BackendConnection backendConnection = mock(BackendConnection.class);
when(backendConnection.getUserName()).thenReturn("root");
when(backendConnection.getUsername()).thenReturn("root");
tablesBackendHandler = new ShowTablesBackendHandler("show tables", mock(SQLStatement.class), backendConnection);
Map<String, SchemaContext> schemaContextMap = getSchemaContextMap();
when(backendConnection.getSchema()).thenReturn("schema_0");
......
......@@ -58,7 +58,7 @@ public final class UseDatabaseBackendHandlerTest {
@SneakyThrows(ReflectiveOperationException.class)
public void setUp() {
backendConnection = mock(BackendConnection.class);
when(backendConnection.getUserName()).thenReturn("root");
when(backendConnection.getUsername()).thenReturn("root");
Field schemaContexts = ProxySchemaContexts.getInstance().getClass().getDeclaredField("schemaContexts");
schemaContexts.setAccessible(true);
schemaContexts.set(ProxySchemaContexts.getInstance(),
......
......@@ -78,9 +78,9 @@ public final class FrontendChannelInboundHandler extends ChannelInboundHandlerAd
private boolean auth(final ChannelHandlerContext context, final ByteBuf message) {
try (PacketPayload payload = databaseProtocolFrontendEngine.getCodecEngine().createPacketPayload(message)) {
AuthenticationResult authResult = databaseProtocolFrontendEngine.getAuthEngine().auth(context, payload, backendConnection);
AuthenticationResult authResult = databaseProtocolFrontendEngine.getAuthEngine().auth(context, payload);
if (authResult.isFinished()) {
backendConnection.setUserName(authResult.getUsername());
backendConnection.setUsername(authResult.getUsername());
backendConnection.setCurrentSchema(authResult.getDatabase());
}
return authResult.isFinished();
......
......@@ -31,7 +31,6 @@ import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandsha
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationEngine;
......@@ -65,7 +64,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
}
@Override
public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload, final BackendConnection backendConnection) {
public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload) {
if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == connectionPhase) {
currentAuthResult = authPhaseFastPath(context, payload);
if (!currentAuthResult.isFinished()) {
......@@ -75,13 +74,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
authenticationMethodMismatch((MySQLPacketPayload) payload);
}
Optional<MySQLServerErrorCode> errorCode = authenticationHandler.login(currentAuthResult.getUsername(), authResponse, currentAuthResult.getDatabase());
if (errorCode.isPresent()) {
context.writeAndFlush(createMySQLErrPacket(errorCode.get(), context));
} else {
backendConnection.setCurrentSchema(currentAuthResult.getDatabase());
backendConnection.setUserName(currentAuthResult.getUsername());
context.writeAndFlush(new MySQLOKPacket(++sequenceId));
}
context.writeAndFlush(errorCode.isPresent() ? createErrorPacket(errorCode.get(), context) : new MySQLOKPacket(++sequenceId));
return AuthenticationResult.finished(currentAuthResult.getUsername(), currentAuthResult.getDatabase());
}
......@@ -96,7 +89,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
if (isClientPluginAuth(packet) && !MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName().equals(packet.getAuthPluginName())) {
connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH;
context.writeAndFlush(new MySQLAuthSwitchRequestPacket(++sequenceId, MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName(), authenticationHandler.getAuthPluginData()));
return AuthenticationResult.continued();
return AuthenticationResult.continued(packet.getUsername(), packet.getDatabase());
}
return AuthenticationResult.finished(packet.getUsername(), packet.getDatabase());
}
......@@ -111,7 +104,7 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {
authResponse = packet.getAuthPluginResponse();
}
private MySQLErrPacket createMySQLErrPacket(final MySQLServerErrorCode errorCode, final ChannelHandlerContext context) {
private MySQLErrPacket createErrorPacket(final MySQLServerErrorCode errorCode, final ChannelHandlerContext context) {
return MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR == errorCode
? new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR, currentAuthResult.getUsername(), getHostAddress(context), currentAuthResult.getDatabase())
: new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR, currentAuthResult.getUsername(), getHostAddress(context), getErrorMessage());
......
......@@ -52,7 +52,7 @@ public final class MySQLComInitDbExecutor implements CommandExecutor {
}
private boolean isAuthorizedSchema(final String schema) {
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUserName()).getAuthorizedSchemas();
Collection<String> authorizedSchemas = ProxySchemaContexts.getInstance().getSchemaContexts().getAuthentication().getUsers().get(backendConnection.getUsername()).getAuthorizedSchemas();
return authorizedSchemas.isEmpty() || authorizedSchemas.contains(schema);
}
}
......@@ -34,7 +34,6 @@ import org.apache.shardingsphere.kernel.context.SchemaContexts;
import org.apache.shardingsphere.kernel.context.StandardSchemaContexts;
import org.apache.shardingsphere.kernel.context.runtime.RuntimeContext;
import org.apache.shardingsphere.kernel.context.schema.ShardingSphereSchema;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationResult;
......@@ -105,7 +104,7 @@ public final class MySQLProtocolFrontendEngineTest {
ProxyUser proxyUser = new ProxyUser("", Collections.singleton("db1"));
setAuthentication(proxyUser);
when(payload.readStringNul()).thenReturn("root");
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload, mock(BackendConnection.class));
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload);
assertThat(actual.getUsername(), is("root"));
assertNull(actual.getDatabase());
assertTrue(actual.isFinished());
......@@ -121,7 +120,7 @@ public final class MySQLProtocolFrontendEngineTest {
when(payload.readStringNulByBytes()).thenReturn("root".getBytes());
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 3307));
when(context.channel()).thenReturn(channel);
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload, mock(BackendConnection.class));
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload);
assertThat(actual.getUsername(), is("root"));
assertNull(actual.getDatabase());
assertTrue(actual.isFinished());
......@@ -138,12 +137,11 @@ public final class MySQLProtocolFrontendEngineTest {
when(payload.readStringNulByBytes()).thenReturn("root".getBytes());
when(context.channel()).thenReturn(channel);
when(channel.remoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getByAddress(new byte[] {(byte) 192, (byte) 168, (byte) 0, (byte) 102}), 3307));
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload, mock(BackendConnection.class));
AuthenticationResult actual = mysqlProtocolFrontendEngine.getAuthEngine().auth(context, payload);
assertThat(actual.getUsername(), is("root"));
assertNull(actual.getDatabase());
assertTrue(actual.isFinished());
verify(context).writeAndFlush(argThat(
(ArgumentMatcher<MySQLErrPacket>) argument -> "Access denied for user 'root'@'192.168.0.102' (using password: YES)".equals(argument.getErrorMessage())));
verify(context).writeAndFlush(argThat((ArgumentMatcher<MySQLErrPacket>) argument -> "Access denied for user 'root'@'192.168.0.102' (using password: YES)".equals(argument.getErrorMessage())));
}
@SneakyThrows
......@@ -168,8 +166,7 @@ public final class MySQLProtocolFrontendEngineTest {
}
private SchemaContexts getSchemaContexts(final Authentication authentication) {
return new StandardSchemaContexts(getSchemaContextMap(), authentication,
new ConfigurationProperties(new Properties()), new MySQLDatabaseType());
return new StandardSchemaContexts(getSchemaContextMap(), authentication, new ConfigurationProperties(new Properties()), new MySQLDatabaseType());
}
private Map<String, SchemaContext> getSchemaContextMap() {
......
......@@ -32,8 +32,8 @@ import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties
import org.apache.shardingsphere.infra.database.type.dialect.MySQLDatabaseType;
import org.apache.shardingsphere.kernel.context.SchemaContext;
import org.apache.shardingsphere.kernel.context.StandardSchemaContexts;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationResult;
import org.junit.Before;
import org.junit.Test;
......@@ -85,7 +85,7 @@ public final class MySQLAuthenticationEngineTest {
MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
when(payload.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue());
authenticationEngine.auth(channelHandlerContext, payload, mock(BackendConnection.class));
authenticationEngine.auth(channelHandlerContext, payload);
assertThat(getConnectionPhase(), is(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH));
}
......@@ -95,17 +95,25 @@ public final class MySQLAuthenticationEngineTest {
MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
when(payload.readStringEOFByBytes()).thenReturn(authResponse);
authenticationEngine.auth(channelHandlerContext, payload, mock(BackendConnection.class));
setAuthenticationResult();
authenticationEngine.auth(channelHandlerContext, payload);
assertThat(getAuthResponse(), is(authResponse));
}
@SneakyThrows(ReflectiveOperationException.class)
private void setAuthenticationResult() {
Field field = MySQLAuthenticationEngine.class.getDeclaredField("currentAuthResult");
field.setAccessible(true);
field.set(authenticationEngine, AuthenticationResult.continued("root", "sharding_db"));
}
@Test
public void assertAuthWithLoginFail() throws NoSuchFieldException, IllegalAccessException {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
ChannelHandlerContext context = getContext();
setSchemas();
when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.of(MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR));
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse), mock(BackendConnection.class));
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse));
verify(context).writeAndFlush(any(MySQLErrPacket.class));
}
......@@ -114,7 +122,7 @@ public final class MySQLAuthenticationEngineTest {
ChannelHandlerContext context = getContext();
setSchemas();
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
authenticationEngine.auth(context, getPayload("root", "ABSENT DATABASE", authResponse), mock(BackendConnection.class));
authenticationEngine.auth(context, getPayload("root", "ABSENT DATABASE", authResponse));
verify(context).writeAndFlush(any(MySQLErrPacket.class));
}
......@@ -124,7 +132,7 @@ public final class MySQLAuthenticationEngineTest {
ChannelHandlerContext context = getContext();
when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.empty());
setSchemas();
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse), mock(BackendConnection.class));
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse));
verify(context).writeAndFlush(any(MySQLOKPacket.class));
}
......@@ -136,10 +144,10 @@ public final class MySQLAuthenticationEngineTest {
new Authentication(), new ConfigurationProperties(new Properties()), new MySQLDatabaseType()));
}
private MySQLPacketPayload getPayload(final String userName, final String database, final byte[] authResponse) {
private MySQLPacketPayload getPayload(final String username, final String database, final byte[] authResponse) {
MySQLPacketPayload result = mock(MySQLPacketPayload.class);
when(result.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_CONNECT_WITH_DB.getValue());
when(result.readStringNul()).thenReturn(userName).thenReturn(database);
when(result.readStringNul()).thenReturn(username).thenReturn(database);
when(result.readStringNulByBytes()).thenReturn(authResponse);
return result;
}
......
......@@ -32,7 +32,6 @@ import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.Postgre
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLRandomGenerator;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLNegativePacket;
import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.schema.ProxySchemaContexts;
import org.apache.shardingsphere.proxy.frontend.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.engine.AuthenticationEngine;
......@@ -67,7 +66,7 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
}
@Override
public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload, final BackendConnection backendConnection) {
public AuthenticationResult auth(final ChannelHandlerContext context, final PacketPayload payload) {
if (SSL_REQUEST_PAYLOAD_LENGTH == payload.getByteBuf().markReaderIndex().readInt() && SSL_REQUEST_CODE == payload.getByteBuf().readInt()) {
context.writeAndFlush(new PostgreSQLSSLNegativePacket());
return AuthenticationResult.continued();
......@@ -78,41 +77,35 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
startupMessageReceived.set(true);
String databaseName = comStartupPacket.getParametersMap().get(DATABASE_NAME_KEYWORD);
if (!Strings.isNullOrEmpty(databaseName) && !ProxySchemaContexts.getInstance().schemaExists(databaseName)) {
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME,
String.format("database \"%s\" does not exist", databaseName));
PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(PostgreSQLErrorCode.INVALID_CATALOG_NAME, String.format("database \"%s\" does not exist", databaseName));
context.writeAndFlush(responsePacket);
context.close();
return AuthenticationResult.continued();
}
backendConnection.setCurrentSchema(databaseName);
String userName = comStartupPacket.getParametersMap().get(USER_NAME_KEYWORD);
if (null == userName || userName.isEmpty()) {
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"user not set in StartupMessage");
String username = comStartupPacket.getParametersMap().get(USER_NAME_KEYWORD);
if (null == username || username.isEmpty()) {
PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, "user not set in StartupMessage");
context.writeAndFlush(responsePacket);
context.close();
return AuthenticationResult.continued();
}
backendConnection.setUserName(userName);
md5Salt = PostgreSQLRandomGenerator.getInstance().generateRandomBytes(4);
context.writeAndFlush(new PostgreSQLAuthenticationMD5PasswordPacket(md5Salt));
return AuthenticationResult.continued(userName, databaseName);
return AuthenticationResult.continued(username, databaseName);
} else {
char messageType = (char) ((PostgreSQLPacketPayload) payload).readInt1();
if ('p' != messageType) {
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"PasswordMessage is expected, message type 'p', but not '" + messageType + "'");
PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(
PostgreSQLErrorCode.SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION, String.format("PasswordMessage is expected, message type 'p', but not '%s'", messageType));
context.writeAndFlush(responsePacket);
context.close();
currentAuthResult = AuthenticationResult.continued();
return currentAuthResult;
}
PostgreSQLPasswordMessagePacket passwordMessagePacket = new PostgreSQLPasswordMessagePacket((PostgreSQLPacketPayload) payload);
PostgreSQLLoginResult loginResult = PostgreSQLAuthenticationHandler.loginWithMd5Password(
backendConnection.getUserName(), backendConnection.getSchema(), md5Salt, passwordMessagePacket);
PostgreSQLLoginResult loginResult = PostgreSQLAuthenticationHandler.loginWithMd5Password(currentAuthResult.getUsername(), currentAuthResult.getDatabase(), md5Salt, passwordMessagePacket);
if (PostgreSQLErrorCode.SUCCESSFUL_COMPLETION != loginResult.getErrorCode()) {
PostgreSQLErrorResponsePacket responsePacket = createPostgreSQLErrorResponsePacket(loginResult.getErrorCode(),
loginResult.getErrorMessage());
PostgreSQLErrorResponsePacket responsePacket = createErrorPacket(loginResult.getErrorCode(), loginResult.getErrorMessage());
context.writeAndFlush(responsePacket);
context.close();
return AuthenticationResult.continued();
......@@ -128,7 +121,7 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin
}
}
private PostgreSQLErrorResponsePacket createPostgreSQLErrorResponsePacket(final PostgreSQLErrorCode errorCode, final String errorMessage) {
private PostgreSQLErrorResponsePacket createErrorPacket(final PostgreSQLErrorCode errorCode, final String errorMessage) {
PostgreSQLErrorResponsePacket result = new PostgreSQLErrorResponsePacket();
result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_SEVERITY, "FATAL");
result.addField(PostgreSQLErrorResponsePacket.FIELD_TYPE_CODE, errorCode.getErrorCode());
......
......@@ -19,7 +19,6 @@ package org.apache.shardingsphere.proxy.frontend.engine;
import io.netty.channel.ChannelHandlerContext;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
/**
* Authentication engine.
......@@ -39,8 +38,7 @@ public interface AuthenticationEngine {
*
* @param context channel handler context
* @param payload packet payload
* @param backendConnection backend connection
* @return authentication result
*/
AuthenticationResult auth(ChannelHandlerContext context, PacketPayload payload, BackendConnection backendConnection);
AuthenticationResult auth(ChannelHandlerContext context, PacketPayload payload);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册